| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598 |
- from rag.vector_db.milvus_vector import HybridRetriever
- from response_info import generate_message, generate_response
- from utils.get_logger import setup_logger
- from datetime import datetime
- from uuid import uuid1
- import mysql.connector
- from mysql.connector import pooling, Error, errors
- import threading
- from concurrent.futures import ThreadPoolExecutor, TimeoutError
- from config import milvus_uri, mysql_config
- logger = setup_logger(__name__)
- # uri = "http://localhost:19530"
- # if 'POOL' not in globals():
- # try:
- # POOL = pooling.MySQLConnectionPool(
- # pool_name="mysql_pool",
- # pool_size=10,
- # **mysql_config
- # )
- # logger.info("MySQL 连接池初始化成功")
- # except Error as e:
- # logger.info(f"初始化 MySQL 连接池失败: {e}")
- # POOL = None
class MilvusOperate:
    """Facade pairing the Milvus vector store (HybridRetriever) with the MySQL
    bookkeeping tables: every vector insert/update/delete mirrors its length
    delta and slice row into MySQL via MysqlOperate.update_total_doc_len."""

    def __init__(self, collection_name: str = "default", embedding_name: str = "e5"):
        # The collection name doubles as the knowledge id in several callers.
        self.collection = collection_name
        self.hybrid_retriever = HybridRetriever(uri=milvus_uri, embedding_name=embedding_name, collection_name=collection_name)
        self.mysql_client = MysqlOperate()

    def _has_collection(self):
        """Return True if the backing Milvus collection already exists."""
        return self.hybrid_retriever.has_collection()

    def _create_collection(self):
        """Create the collection; 400-style response if it already exists."""
        if self._has_collection():
            resp = {"code": 400, "message": "数据库已存在"}
        else:
            create_result = self.hybrid_retriever.build_collection()
            resp = generate_message(create_result)
        return resp

    def _delete_collection(self):
        """Drop the whole collection and wrap the outcome in a response dict."""
        delete_result = self.hybrid_retriever.delete_collection(self.collection)
        return generate_message(delete_result)

    def _put_by_id(self, slice_json):
        """Update one slice's text in Milvus, then mirror the length delta in MySQL.

        Returns a response dict; the status becomes "update_error" if either
        the vector store or the MySQL bookkeeping update failed.
        """
        slice_id = slice_json.get("slice_id", None)
        slice_text = slice_json.get("slice_text", None)
        update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text)
        if update_result.endswith("success"):
            # Vector update succeeded -> sync knowledge/document totals in MySQL.
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "update",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            update_flag = False

        if not update_flag:
            update_result = "update_error"

        return generate_message(update_result)

    def _insert_slice(self, slice_json):
        """Insert a brand-new slice (fresh uuid1 id) into Milvus, then MySQL.

        Returns {"status": "insert_success"|"insert_error", "slice_id": <new id>}.
        """
        slice_id = str(uuid1())
        knowledge_id = slice_json.get("knowledge_id")
        doc_id = slice_json.get("document_id")
        slice_text = slice_json.get("slice_text", None)
        doc_name = slice_json.get("doc_name")
        chunk_len = len(slice_text)
        metadata = {
            "content": slice_text,
            "doc_id": doc_id,
            "chunk_id": slice_id,
            "metadata": {"source": doc_name, "chunk_len": chunk_len},
            "Chapter": slice_json.get("Chapter", ""),
            "Father_Chapter": slice_json.get("Father_Chapter", ""),
        }
        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
        if insert_flag:
            # Vector insert succeeded -> update totals and slice_info in MySQL.
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "insert",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
                "slice_index": slice_json.get("slice_index"),
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error(f"插入向量库出错:{insert_str}")
            update_flag = False
            update_str = "向量库写入出错"

        if not update_flag:
            logger.error(f"新增切片中mysql数据库出错:{update_str}")
            success = "insert_error"
        else:
            success = "insert_success"

        return {"status": success, "slice_id": slice_id}

    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
        """Delete one slice from Milvus and subtract its length from MySQL totals."""
        logger.info(f"删除的切片id:{chunk_id}")
        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id)
        if delete_result.endswith("success"):
            chunk_len = delete_chunk_len[0]
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": document_id,
                "chunk_len": -chunk_len,  # negative delta shrinks the stored totals
                "operate": "delete",
                "chunk_id": chunk_id
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error("根据chunk id删除向量库失败")
            update_flag = False
            update_str = "根据chunk id删除失败"

        if not update_flag:
            logger.error(update_str)
            delete_result = "delete_error"

        return generate_message(delete_result)

    def _delete_by_doc_id(self, doc_id: str = None):
        """Delete every vector-store row belonging to the given document id."""
        logger.info(f"删除数据的id:{doc_id}")
        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
        return generate_message(delete_result)

    def _search_by_chunk_id(self, chunk_id):
        """Look up a single slice by chunk id (empty result if no collection)."""
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
        else:
            query_result = []
        logger.info(f"根据切片查询到的信息:{query_result}")
        return generate_response(query_result)

    def _search_by_chunk_id_list(self, chunk_id_list):
        """Resolve a list of chunk ids to their text contents (list of str)."""
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list)
        else:
            query_result = []
        logger.info(f"召回的切片列表查询切片信息:{query_result}")
        return [chunk_dict.get("content") for chunk_dict in query_result]

    def _search_by_key_word(self, search_json):
        """Keyword-filter the slices of a document; paginated response.

        Bug fix: pagination defaults are now resolved before the collection
        check. Previously page_num/page_size were unbound when the collection
        was missing (NameError), and a second unguarded
        `search_json.get("pageNum")` clobbered the default with None.
        """
        page_num = search_json.get("pageNum", 1)
        page_size = search_json.get("pageSize", 10)
        if self._has_collection():
            doc_id = search_json.get("document_id", None)
            text = search_json.get("text", None)
            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
        else:
            query_result = []
        return generate_response(query_result, page_num, page_size)

    def _insert_data(self, docs):
        """Insert docs one by one, stopping at the first failure.

        Returns (flag_or_"insert_error", info_from_last_attempt).
        """
        insert_flag = ""
        insert_info = ""
        for doc in docs:
            chunk = doc.get("content")
            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
            if not insert_flag:
                break
        resp = insert_flag if insert_flag else "insert_error"
        return resp, insert_info

    def _batch_insert_data(self, docs, text_lists):
        """Batch-insert pre-chunked texts together with their metadata dicts."""
        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
        return insert_flag, insert_info

    def _search(self, query, k, mode):
        """Top-k retrieval in the given mode (delegated to HybridRetriever)."""
        return self.hybrid_retriever.search(query, k, mode)

    def _query_by_scalar_field(self, doc_id: str, field_name: str, field_value: str):
        """
        Query rows of a document by a scalar field.

        Args:
            doc_id: document id
            field_name: field name (e.g. Father_Chapter)
            field_value: field value to match
        Returns:
            List of matching rows.
        """
        return self.hybrid_retriever.query_by_scalar_field(doc_id, field_name, field_value)

    def _copy_docs_to_new_collection(self, new_collection_name, doc_ids, embedding_name="e5"):
        """Copy the given documents into another collection, minting new
        snowflake doc/chunk ids for every copied row.

        Args:
            new_collection_name: target collection name (also the new knowledge_id).
            doc_ids: source document ids to copy.
            embedding_name: embedding model of the target collection.

        Returns:
            Response dict including doc_id_mapping and chunk_id_mapping
            ({old_id: new_id}); a collection created here is dropped again
            if the bulk insert fails.
        """
        try:
            # 1. Read all rows of the requested documents from the source collection.
            logger.info(f"从集合 {self.collection} 查询文档: {doc_ids}")
            query_results = self.hybrid_retriever.query_by_doc_ids(doc_ids)

            if not query_results:
                return {"code": 404, "message": "未找到匹配的文档数据", "doc_id_mapping": {}, "chunk_id_mapping": {}}

            logger.info(f"查询到 {len(query_results)} 条数据")

            # 2. Make sure the target collection exists; create it if needed.
            target_milvus_client = MilvusOperate(collection_name=new_collection_name, embedding_name=embedding_name)

            collection_exists = target_milvus_client._has_collection()
            if not collection_exists:
                logger.info(f"创建新集合: {new_collection_name}")
                create_result = target_milvus_client._create_collection()
                if create_result.get("code") != 200:
                    create_result["doc_id_mapping"] = {}
                    return create_result
            else:
                logger.info(f"集合 {new_collection_name} 已存在,直接插入数据")

            # 3. Mint a fresh snowflake doc_id for every source doc_id.
            doc_id_mapping = {old_doc_id: generate_snowflake_id() for old_doc_id in doc_ids}

            # 4. Rebuild the rows with fresh ids (the pk field is dropped).
            insert_data = []
            chunk_id_mapping = {}  # {old_chunk_id: new_chunk_id}
            for item in query_results:
                old_doc_id = item.get("doc_id")
                old_chunk_id = item.get("chunk_id")
                new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id())
                new_chunk_id = generate_snowflake_id()  # every slice gets a new chunk_id
                chunk_id_mapping[old_chunk_id] = new_chunk_id

                insert_data.append({
                    "content": item.get("content"),
                    "dense_vector": item.get("dense_vector"),
                    "doc_id": new_doc_id,
                    "chunk_id": new_chunk_id,
                    "Father_Chapter": item.get("Father_Chapter"),
                    "Chapter": item.get("Chapter"),
                    "metadata": item.get("metadata")
                })

            # 5. Bulk-insert the rebuilt rows into the target collection.
            logger.info(f"开始向集合 {new_collection_name} 插入 {len(insert_data)} 条数据")
            try:
                target_milvus_client.hybrid_retriever.client.insert(
                    collection_name=new_collection_name,
                    data=insert_data
                )
                logger.info(f"成功向集合插入数据")

                return {
                    "code": 200,
                    "message": "复制成功",
                    "data": {
                        "source_collection": self.collection,
                        "target_collection": new_collection_name,
                        "doc_ids": doc_ids,
                        "total_records": len(insert_data),
                        "collection_existed": collection_exists
                    },
                    "doc_id_mapping": doc_id_mapping,
                    "chunk_id_mapping": chunk_id_mapping
                }
            except Exception as e:
                logger.error(f"插入数据到集合失败: {e}")
                # Roll back: only drop the collection if we created it above.
                if not collection_exists:
                    target_milvus_client._delete_collection()
                return {"code": 500, "message": f"插入数据失败: {str(e)}", "doc_id_mapping": {}, "chunk_id_mapping": {}}

        except Exception as e:
            logger.error(f"复制文档到新集合失败: {e}")
            return {"code": 500, "message": f"复制失败: {str(e)}", "doc_id_mapping": {}, "chunk_id_mapping": {}}
- # class MysqlOperate:
- # def get_connection(self):
- # """
- # 从连接池中获取一个连接
- # :return: 数据库连接对象
- # """
- # # try:
- # # with ThreadPoolExecutor() as executor:
- # # future = executor.submit(POOL.get_connection)
- # # connection = future.result(timeout=5.0) # 设置超时时间为5秒
- # # logger.info("成功从连接池获取连接")
- # # return connection, "success"
- # # except TimeoutError:
- # # logger.error("获取mysql数据库连接池超时")
- # # return None, "mysql获取连接池超时"
- # # except errors.InterfaceError as e:
- # # logger.error(f"MySQL 接口异常:{e}")
- # # return None, "mysql接口异常"
- # # except errors.OperationalError as e:
- # # logger.error(f"MySQL 操作错误:{e}")
- # # return None, "mysql 操作错误"
- # # except Error as e:
- # # logger.error(f"无法从连接池获取连接: {e}")
- # # return None, str(e)
- # connection = None
- # event = threading.Event()
- # def target():
- # nonlocal connection
- # try:
- # connection = POOL.get_connection()
- # finally:
- # event.set()
- # thread = threading.Thread(target=target)
- # thread.start()
- # event.wait(timeout=5)
- # if thread.is_alive():
- # # 超时处理
- # logger.error("获取连接超时")
- # return None, "获取连接超时"
- # else:
- # if connection:
- # return connection, "success"
- # else:
- # logger.error("获取连接失败")
- # return None, "获取连接失败"
- # def insert_to_slice(self, docs, knowledge_id, doc_id):
- # """
- # 插入数据到切片信息表中 slice_info
- # """
- # connection = None
- # cursor = None
- # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- # values = []
- # connection, cennction_info = self.get_connection()
- # if not connection:
- # return False, cennction_info
-
- # for chunk in docs:
- # slice_id = chunk.get("chunk_id")
- # slice_text = chunk.get("content")
- # chunk_index = chunk.get("metadata").get("chunk_index")
- # values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index))
- # try:
- # cursor = connection.cursor()
- # # insert_sql = """
- # # INSERT INTO slice_info (
- # # slice_id,
- # # knowledge_id,
- # # document_id,
- # # slice_text,
- # # create_time,
- # # slice_index
- # # ) VALUES (%s, %s, %s, %s, %s,%s)
- # # """
-
- # # 容错“for key 'UK_ID_TYPE_KEY'”
- # insert_sql = """
- # INSERT INTO slice_info (
- # slice_id,
- # knowledge_id,
- # document_id,
- # slice_text,
- # create_time,
- # slice_index
- # ) VALUES (%s, %s, %s, %s, %s, %s)
- # ON DUPLICATE KEY UPDATE
- # slice_text = VALUES(slice_text),
- # create_time = VALUES(create_time),
- # slice_index = VALUES(slice_index)
- # """
-
- # cursor.executemany(insert_sql, values)
- # connection.commit()
- # logger.info(f"批量插入切片数据成功。")
- # return True, "success"
- # except Error as e:
- # logger.error(f"数据库操作出错:{e}")
- # connection.rollback()
- # return False, str(e)
- # finally:
- # # if cursor:
- # cursor.close()
- # # if connection and connection.is_connected():
- # connection.close()
- # def delete_to_slice(self, doc_id):
- # """
- # 删除 slice_info库中切片信息
- # """
- # connection = None
- # cursor = None
- # connection, connection_info = self.get_connection()
- # if not connection:
- # return False, connection_info
- # try:
- # cursor = connection.cursor()
- # delete_sql = f"DELETE FROM slice_info WHERE document_id = %s"
- # cursor.execute(delete_sql, (doc_id,))
- # connection.commit()
- # logger.info(f"删除数据成功")
- # return True, "success"
- # except Error as e:
- # logger.error(f"根据{doc_id}删除数据失败:{e}")
- # connection.rollback()
- # return False, str(e)
- # finally:
- # # if cursor:
- # cursor.close()
- # # if connection and connection.is_connected():
- # connection.close()
- # def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
- # """
- # 批量插入数据到指定表
- # """
- # connection = None
- # cursor = None
- # connection, connection_info = self.get_connection()
- # if not connection:
- # return False, connection_info
-
- # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- # values = []
- # for img_key, img_value in image_dict.items():
- # origin_text = img_key
- # media_url = img_value
- # values.append((knowledge_id, doc_id, origin_text, "image", media_url, date_now))
- # try:
- # cursor = connection.cursor()
- # # insert_sql = """
- # # INSERT INTO bm_media_replacement (
- # # knowledge_id,
- # # document_id,
- # # origin_text,
- # # media_type,
- # # media_url,
- # # create_time
- # # ) VALUES (%s, %s, %s, %s, %s, %s)
- # # """
-
- # # 容错“for key 'UK_ID_TYPE_KEY'”
- # insert_sql = """
- # INSERT INTO bm_media_replacement (
- # knowledge_id,
- # document_id,
- # origin_text,
- # media_type,
- # media_url,
- # create_time
- # ) VALUES (%s, %s, %s, %s, %s, %s)
- # ON DUPLICATE KEY UPDATE
- # origin_text = VALUES(origin_text),
- # media_type = VALUES(media_type),
- # media_url = VALUES(media_url),
- # create_time = VALUES(create_time)
- # """
- # cursor.executemany(insert_sql, values)
- # connection.commit()
- # logger.info(f"插入到bm_media_replacement表成功")
- # return True, "success"
- # except Error as e:
- # logger.error(f"数据库操作出错:{e}")
- # connection.rollback()
- # return False, str(e)
- # finally:
- # # if cursor:
- # cursor.close()
- # # if connection and connection.is_connected():
- # connection.close()
- # def delete_image_url(self, doc_id):
- # """
- # 根据doc id删除bm_media_replacement中的数据
- # """
- # connection = None
- # cursor = None
- # connection, connection_info = self.get_connection()
- # if not connection:
- # return False, connection_info
-
- # try:
- # cursor = connection.cursor()
- # delete_sql = f"DELETE FROM bm_media_replacement WHERE document_id = %s"
- # cursor.execute(delete_sql, (doc_id,))
- # connection.commit()
- # logger.info(f"根据{doc_id} 删除bm_media_replacement表中数据成功")
- # return True, "success"
- # except Error as e:
- # logger.error(f"根据{doc_id}删除 bm_media_replacement 数据库操作出错:{e}")
- # connection.rollback()
- # return False, str(e)
- # finally:
- # # if cursor:
- # cursor.close()
- # # if connection and connection.is_connected():
- # connection.close()
- # def update_total_doc_len(self, update_json):
- # """
- # 更新长度表和文档长度表,删除slice info表, 插入slice info 切片信息
- # """
- # knowledge_id = update_json.get("knowledge_id")
- # doc_id = update_json.get("doc_id")
- # chunk_len = update_json.get("chunk_len")
- # operate = update_json.get("operate")
- # chunk_id = update_json.get("chunk_id")
- # chunk_text = update_json.get("chunk_text")
- # connection = None
- # cursor = None
- # connection, connection_info = self.get_connection()
- # if not connection:
- # return False, connection_info
- # try:
- # cursor = connection.cursor()
- # query_doc_word_num_sql = f"select word_num,slice_total from bm_document where document_id = %s"
- # query_knowledge_word_num_sql = f"select word_num from bm_knowledge where knowledge_id = %s"
- # cursor.execute(query_doc_word_num_sql, (doc_id,))
- # doc_result = cursor.fetchone()
- # logger.info(f"查询到的文档长度信息:{doc_result}")
- # cursor.execute(query_knowledge_word_num_sql, (knowledge_id, ))
- # knowledge_result = cursor.fetchone()
- # logger.info(f"查询到的知识库总长度信息:{knowledge_result}")
- # if not doc_result:
- # new_word_num = 0
- # slice_total = 0
- # else:
- # old_word_num = doc_result[0]
- # slice_total = doc_result[1]
- # new_word_num = old_word_num + chunk_len
- # slice_total -= 1 if slice_total else 0
- # if not knowledge_result:
- # new_knowledge_word_num = 0
- # else:
- # old_knowledge_word_num = knowledge_result[0]
- # new_knowledge_word_num = old_knowledge_word_num + chunk_len
- # if operate == "update":
- # update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s"
- # cursor.execute(update_sql, (new_word_num, doc_id))
- # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- # update_slice_sql = f"UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s"
- # cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id))
- # elif operate == "insert":
- # query_slice_info_index_sql = f"select MAX(slice_index) from slice_info where document_id = %s"
- # cursor.execute(query_slice_info_index_sql, (doc_id,))
- # chunk_index_result = cursor.fetchone()[0]
- # # logger.info(chunk_index_result)
- # if chunk_index_result:
- # chunk_max_index = int(chunk_index_result)
- # else:
- # chunk_max_index = 0
- # update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s"
- # cursor.execute(update_sql, (new_word_num, doc_id))
- # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- # insert_slice_sql = "INSERT INTO slice_info (slice_id,knowledge_id,document_id,slice_text,create_time, slice_index) VALUES (%s, %s, %s, %s, %s, %s)"
- # cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index+1))
- # else:
- # update_sql = f"UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s"
- # cursor.execute(update_sql, (new_word_num, slice_total, doc_id))
- # # 删除切片id对应的切片
- # delete_slice_sql = f"DELETE FROM slice_info where slice_id = %s"
- # cursor.execute(delete_slice_sql, (chunk_id, ))
- # update_knowledge_sql = f"UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s"
- # cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id))
- # connection.commit()
- # logger.info("bm_document和bm_knowledge数据更新成功")
- # return True, "success"
- # except Error as e:
- # logger.error(f"数据库操作出错:{e}")
- # connection.rollback()
- # return False, str(e)
- # finally:
- # # if cursor:
- # cursor.close()
- # # if connection and connection.is_connected():
- # connection.close()
# Task-status bookkeeping reuses existing bm_document columns rather than
# adding new ones (see update_task_status_* in MysqlOperate below):
TABLE_NAME = "bm_document"       # table holding per-document task state
STATUS_FIELD = "update_by"       # repurposed column: task status 0/1/2
TASK_ID_FIELD = "document_id"    # task id == document id
USER_ID_FIELD = "remark"         # repurposed column: user id on completion
import time  # NOTE(review): mid-file import; used by SafeMySQLPool below
- # ========== 全局初始化连接池(自动检测 + 超时保护) ==========
class SafeMySQLPool:
    """MySQL connection pool with timeout-guarded checkout, per-connection
    tracking, and a daemon thread that force-closes connections left
    unclosed for longer than ``idle_timeout`` seconds."""

    def __init__(self, pool_size=10, conn_timeout=10, idle_timeout=60, **mysql_config):
        mysql_config.setdefault("connect_timeout", conn_timeout)
        mysql_config.setdefault("pool_reset_session", True)
        self._pool = pooling.MySQLConnectionPool(
            pool_name="safe_mysql_pool",
            pool_size=pool_size,
            **mysql_config
        )
        # Re-entrant lock: _auto_reclaimer holds it while calling a wrapped
        # close(), which acquires it again — a plain Lock would deadlock.
        self._lock = threading.RLock()
        self._active_conns = {}  # {id(conn): (conn, checkout_time)}
        self._idle_timeout = idle_timeout
        self._stop_event = threading.Event()
        threading.Thread(target=self._auto_reclaimer, daemon=True).start()

    def get_connection(self, timeout=10):
        """Check out a live connection; raise TimeoutError after `timeout` seconds."""
        start = time.time()
        while True:
            try:
                conn = self._pool.get_connection()
                # Revive dead connections before handing them out.
                conn.ping(reconnect=True, attempts=3, delay=2)
                with self._lock:
                    self._active_conns[id(conn)] = (conn, time.time())
                return self._wrap_connection(conn)
            except errors.PoolError:
                # Pool exhausted: retry until the caller's timeout expires.
                if time.time() - start > timeout:
                    raise TimeoutError(f"获取 MySQL 连接超时(超过 {timeout}s)")
                time.sleep(0.3)

    def _wrap_connection(self, conn):
        """Patch conn.close() so the tracking table is cleaned up on close."""
        pool = self
        orig_close = conn.close
        def safe_close():
            try:
                orig_close()
            finally:
                with pool._lock:
                    pool._active_conns.pop(id(conn), None)
        conn.close = safe_close
        return conn

    def _auto_reclaimer(self):
        """Background loop: force-close connections held past idle_timeout.

        NOTE(review): the timestamp is set at checkout and never refreshed, so
        a connection legitimately used for longer than idle_timeout is also
        reclaimed — confirm this is the intended behaviour.
        """
        while not self._stop_event.is_set():
            time.sleep(5)
            now = time.time()
            with self._lock:
                to_remove = []
                for cid, (conn, last_used) in list(self._active_conns.items()):
                    if now - last_used > self._idle_timeout:
                        try:
                            conn.close()
                            logger.warning(f"[回收] 已回收超时未关闭连接 (idle={int(now - last_used)}s)")
                        except Exception as e:
                            logger.error(f"[回收] 回收连接失败: {e}")
                        to_remove.append(cid)
                for cid in to_remove:
                    self._active_conns.pop(cid, None)

    def close_all(self):
        """Stop the reclaimer thread and close every tracked connection."""
        self._stop_event.set()
        with self._lock:
            for conn, _ in self._active_conns.values():
                try:
                    conn.close()
                except Exception:
                    # Fix: was a bare `except:` which also swallowed
                    # SystemExit/KeyboardInterrupt during shutdown.
                    pass
            self._active_conns.clear()
# ========== 初始化连接池 ==========
# Module-level singleton; the globals() guard avoids re-creating the pool if
# the module body runs more than once. POOL stays None on failure and
# MysqlOperate.get_connection checks for that.
if "POOL" not in globals():
    try:
        POOL = SafeMySQLPool(pool_size=10, idle_timeout=60, **mysql_config)
        logger.info("MySQL 连接池初始化成功")
    except Error as e:
        logger.error(f"MySQL 连接池初始化失败: {e}")
        POOL = None
- """
- 雪花算法获取唯一id
- """
- import time
- import threading
class Snowflake:
    """Thread-safe Snowflake id generator.

    Bit layout (low to high): 12-bit sequence | 5-bit worker | 5-bit
    datacenter | millisecond timestamp relative to the Twitter epoch.
    """

    def __init__(self, datacenter_id=1, worker_id=1):
        self.worker_id_bits = 5
        self.datacenter_id_bits = 5
        self.sequence_bits = 12
        self.max_worker_id = -1 ^ (-1 << self.worker_id_bits)          # 31
        self.max_datacenter_id = -1 ^ (-1 << self.datacenter_id_bits)  # 31
        # Fix: the maxima above were computed but never enforced; an
        # out-of-range id silently corrupts the bit layout and can collide.
        if not 0 <= worker_id <= self.max_worker_id:
            raise ValueError(f"worker_id must be in [0, {self.max_worker_id}]")
        if not 0 <= datacenter_id <= self.max_datacenter_id:
            raise ValueError(f"datacenter_id must be in [0, {self.max_datacenter_id}]")
        self.worker_id = worker_id
        self.datacenter_id = datacenter_id
        self.sequence = 0
        self.worker_id_shift = self.sequence_bits
        self.datacenter_id_shift = self.sequence_bits + self.worker_id_bits
        self.timestamp_left_shift = self.sequence_bits + self.worker_id_bits + self.datacenter_id_bits
        self.twepoch = 1288834974657  # Twitter epoch (2010-11-04) in ms
        self.last_timestamp = -1
        self.lock = threading.Lock()

    def _timestamp(self):
        """Current wall-clock time in milliseconds."""
        return int(time.time() * 1000)

    def _wait_next_ms(self, last):
        """Spin until the clock advances past `last` (sequence exhausted)."""
        ts = self._timestamp()
        while ts <= last:
            ts = self._timestamp()
        return ts

    def generate_id(self):
        """Return a unique, monotonically increasing 20-digit id string.

        Raises:
            Exception: if the system clock moved backwards.
        """
        with self.lock:
            timestamp = self._timestamp()
            if timestamp < self.last_timestamp:
                raise Exception("Clock moved backwards, refusing to generate id")
            if timestamp == self.last_timestamp:
                # Same millisecond: bump the 12-bit sequence; on rollover,
                # busy-wait for the next millisecond.
                self.sequence = (self.sequence + 1) & ((1 << self.sequence_bits) - 1)
                if self.sequence == 0:
                    timestamp = self._wait_next_ms(timestamp)
            else:
                self.sequence = 0
            self.last_timestamp = timestamp
            snowflake_id = (
                ((timestamp - self.twepoch) << self.timestamp_left_shift) |
                (self.datacenter_id << self.datacenter_id_shift) |
                (self.worker_id << self.worker_id_shift) |
                self.sequence
            )
            # Zero-pad to a fixed 20-character numeric string.
            return str(snowflake_id).zfill(20)
# ========== 全局雪花算法 ID 生成器 ==========
# Module-level singleton so every caller shares one sequence counter/lock.
snowflake_id_generator = Snowflake(datacenter_id=1, worker_id=1)
def generate_snowflake_id():
    """
    Module-level helper: generate a unique id via the Snowflake algorithm.
    Returns a 20-digit numeric string.
    """
    return snowflake_id_generator.generate_id()
- # ========== MysqlOperate 类 ==========
- class MysqlOperate:
- def get_connection(self):
- """安全获取连接"""
- if not POOL:
- return None, "连接池未初始化"
- try:
- connection = POOL.get_connection(timeout=5)
- return connection, "success"
- except TimeoutError as e:
- logger.error(str(e))
- return None, "获取连接超时"
- except Error as e:
- logger.error(f"MySQL 获取连接失败: {e}")
- return None, str(e)
    def _execute_many(self, sql, values, success_msg, err_msg):
        """Shared template: run cursor.executemany(sql, values) in one transaction.

        Commits on success, rolls back on mysql.connector Error, and always
        returns a (bool_ok, "success" | error_string) tuple. The cursor and
        the pooled connection are closed in the finally block either way.
        """
        connection, info = self.get_connection()
        if not connection:
            return False, info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.executemany(sql, values)
            connection.commit()
            logger.info(success_msg)
            return True, "success"
        except Error as e:
            connection.rollback()
            logger.error(f"{err_msg}: {e}")
            return False, str(e)
        finally:
            # connection.close() returns the connection to the pool
            # (wrapped by SafeMySQLPool), it does not tear it down.
            if cursor:
                cursor.close()
            if connection:
                connection.close()
- def insert_to_slice(self, docs, knowledge_id, doc_id):
- """批量插入切片信息(同时存储 slice_text 和 old_slice_text)"""
- date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- values = [
- (
- chunk.get("chunk_id"),
- knowledge_id,
- doc_id,
- chunk.get("content"),
- chunk.get("content"), # old_slice_text 存储原始文本
- date_now,
- chunk.get("metadata", {}).get("chunk_index"),
- chunk.get("Chapter"),
- chunk.get("Father_Chapter", "")
- )
- for chunk in docs
- ]
-
- sql = """
- INSERT INTO slice_info (
- slice_id, knowledge_id, document_id, slice_text, old_slice_text, create_time, slice_index, section, parent_section
- ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
- ON DUPLICATE KEY UPDATE
- slice_text = VALUES(slice_text),
- create_time = VALUES(create_time),
- slice_index = VALUES(slice_index),
- section = VALUES(section),
- parent_section = VALUES(parent_section)
- """
- return self._execute_many(sql, values, "批量插入切片数据成功", "插入 slice_info 出错")
- def delete_to_slice(self, doc_id):
- """删除切片"""
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor()
- cursor.execute("DELETE FROM slice_info WHERE document_id = %s", (doc_id,))
- connection.commit()
- logger.info(f"删除 slice_info 数据成功")
- return True, "success"
- except Error as e:
- connection.rollback()
- logger.error(f"删除 slice_info 出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
- def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
- """插入图片映射表"""
- date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- values = [
- (knowledge_id, doc_id, k, "image", v, date_now)
- for k, v in image_dict.items()
- ]
- sql = """
- INSERT INTO bm_media_replacement (
- knowledge_id, document_id, origin_text, media_type, media_url, create_time
- ) VALUES (%s, %s, %s, %s, %s, %s)
- ON DUPLICATE KEY UPDATE
- origin_text = VALUES(origin_text),
- media_type = VALUES(media_type),
- media_url = VALUES(media_url),
- create_time = VALUES(create_time)
- """
- return self._execute_many(sql, values, "插入 bm_media_replacement 成功", "插入 bm_media_replacement 出错")
    def delete_image_url(self, doc_id):
        """Delete all bm_media_replacement rows of one document.

        Returns (True, "success") or (False, error_string); the pooled
        connection and cursor are always released in the finally block.
        """
        connection, info = self.get_connection()
        if not connection:
            return False, info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.execute("DELETE FROM bm_media_replacement WHERE document_id = %s", (doc_id,))
            connection.commit()
            logger.info(f"删除 bm_media_replacement 成功")
            return True, "success"
        except Error as e:
            # Roll back the partial transaction before reporting the failure.
            connection.rollback()
            logger.error(f"删除 bm_media_replacement 出错: {e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection:
                connection.close()
- def update_task_status_start(self, task_id):
- """更新任务状态为开始(1)"""
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor()
- sql = f"UPDATE {TABLE_NAME} SET {STATUS_FIELD} = %s WHERE {TASK_ID_FIELD} = %s"
- cursor.execute(sql, (1, task_id))
- connection.commit()
- logger.info(f"任务 {task_id} 状态更新为开始(1)")
- return True, "success"
- except Error as e:
- connection.rollback()
- logger.error(f"更新任务状态为开始失败: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
- def update_task_status_complete(self, task_id, user_id):
- """更新任务状态为完成(2)并更新用户ID"""
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor()
- sql = f"UPDATE {TABLE_NAME} SET {STATUS_FIELD} = %s, {USER_ID_FIELD} = %s WHERE {TASK_ID_FIELD} = %s"
- cursor.execute(sql, (2, user_id, task_id))
- connection.commit()
- logger.info(f"任务 {task_id} 状态更新为完成(2),用户ID: {user_id}")
- return True, "success"
- except Error as e:
- connection.rollback()
- logger.error(f"更新任务状态为完成失败: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
- def update_task_status_error(self, task_id):
- """更新任务状态为错误(0)"""
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor()
- sql = f"UPDATE {TABLE_NAME} SET {STATUS_FIELD} = %s WHERE {TASK_ID_FIELD} = %s"
- cursor.execute(sql, (0, task_id))
- connection.commit()
- logger.info(f"任务 {task_id} 状态更新为错误(0)")
- return True, "success"
- except Error as e:
- connection.rollback()
- logger.error(f"更新任务状态为错误失败: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
- def delete_document(self, task_id):
- """删除 bm_document 表中的记录(取消任务前清理)"""
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor()
- sql = f"DELETE FROM {TABLE_NAME} WHERE {TASK_ID_FIELD} = %s"
- cursor.execute(sql, (task_id,))
- affected_rows = cursor.rowcount
- connection.commit()
- logger.info(f"删除 bm_document 记录成功: task_id={task_id}, 影响行数={affected_rows}")
- return True, affected_rows
- except Error as e:
- connection.rollback()
- logger.error(f"删除 bm_document 记录失败: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
- def query_parent_generation_enabled(self, doc_ids: list):
- """
- 查询文档的 parent_generation_enabled 字段
- 返回 parent_generation_enabled=1 的 doc_id 集合
- """
- if not doc_ids:
- return set()
- connection, info = self.get_connection()
- if not connection:
- return set()
- cursor = None
- try:
- cursor = connection.cursor()
- placeholders = ", ".join(["%s"] * len(doc_ids))
- sql = f"SELECT document_id FROM bm_document WHERE document_id IN ({placeholders}) AND parent_generation_enabled = 1"
- cursor.execute(sql, tuple(doc_ids))
- results = cursor.fetchall()
- return {row[0] for row in results}
- except Error as e:
- logger.error(f"查询 parent_generation_enabled 失败: {e}")
- return set()
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
- def insert_oss_record(self, file_name, url, tenant_id="000000", file_extension=""):
- """插入 OSS 记录到 sys_oss 表"""
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- # 生成类似 "a2922173479520702464" 的 oss_id 20位
- # import random
- # oss_id = f"a{int(time.time() * 1000)}{random.randint(1000, 9999)}"
- oss_id = Snowflake().generate_id()
- create_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-
- cursor = connection.cursor()
- sql = """
- INSERT INTO sys_oss (oss_id, tenant_id, file_name, original_name, url, create_time, file_suffix)
- VALUES (%s, %s, %s, %s, %s, %s, %s)
- """
- cursor.execute(sql, (oss_id, tenant_id, file_name, file_name, url, create_time, file_extension))
- connection.commit()
- logger.info(f"OSS 记录插入成功: {oss_id}")
- return True, oss_id
- except Error as e:
- connection.rollback()
- logger.error(f"插入 OSS 记录失败: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
    def update_total_doc_len(self, update_json):
        """Apply a slice-level edit (update / insert / soft-delete) to slice_info.

        Despite the name, the word-count updates for bm_document and
        bm_knowledge are currently commented out below — their new values
        are still computed, but only the slice_info mutation is persisted.

        Args:
            update_json: dict with keys knowledge_id, doc_id, chunk_len,
                operate ("update" | "insert" | "delete"), chunk_id,
                chunk_text (optional, defaults to ""), slice_index.

        Returns:
            (True, "success") on success, (False, error_message) on DB error.
        """
        knowledge_id = update_json.get("knowledge_id")
        doc_id = update_json.get("doc_id")
        # NOTE(review): chunk_len is assumed numeric (presumably negative for
        # deletions); the arithmetic below raises TypeError if it is None —
        # confirm against callers.
        chunk_len = update_json.get("chunk_len")
        operate = update_json.get("operate")
        chunk_id = update_json.get("chunk_id")
        chunk_text = update_json.get("chunk_text", "")
        slice_index = update_json.get("slice_index")

        connection, info = self.get_connection()
        if not connection:
            return False, info

        cursor = None
        try:
            cursor = connection.cursor()

            # Current word/slice counters of the document.
            cursor.execute(
                "SELECT word_num, slice_total FROM bm_document WHERE document_id = %s",
                (doc_id,)
            )
            doc_result = cursor.fetchone()
            logger.info(f"查询到的文档长度信息:{doc_result}")

            # Current word counter of the knowledge base.
            cursor.execute(
                "SELECT word_num FROM bm_knowledge WHERE knowledge_id = %s",
                (knowledge_id,)
            )
            knowledge_result = cursor.fetchone()
            logger.info(f"查询到的知识库总长度信息:{knowledge_result}")

            # Recompute the document length. NOTE(review): currently unused —
            # the UPDATE statements that would persist new_word_num and
            # slice_total are commented out further down.
            if not doc_result:
                new_word_num = chunk_len if chunk_len > 0 else 0
                slice_total = 0
            else:
                old_word_num = doc_result[0] or 0
                slice_total = doc_result[1] or 0
                new_word_num = old_word_num + chunk_len
                if operate == "delete":
                    slice_total = max(0, slice_total - 1)

            # Recompute the knowledge-base length (also currently unused).
            if not knowledge_result:
                new_knowledge_word_num = chunk_len if chunk_len > 0 else 0
            else:
                old_knowledge_word_num = knowledge_result[0] or 0
                new_knowledge_word_num = old_knowledge_word_num + chunk_len

            date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

            # Dispatch on the operation type.
            if operate == "update":
                # # 更新文档长度
                # cursor.execute(
                #     "UPDATE bm_document SET word_num = %s WHERE document_id = %s",
                #     (new_word_num, doc_id)
                # )
                # Rewrite the slice text and bump its update timestamp.
                cursor.execute(
                    "UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s",
                    (chunk_text, date_now, chunk_id)
                )

            elif operate == "insert":
                # 获取最大切片索引
                # cursor.execute(
                #     "SELECT MAX(slice_index) FROM slice_info WHERE document_id = %s",
                #     (doc_id,)
                # )
                # chunk_index_result = cursor.fetchone()[0]
                # chunk_max_index = int(chunk_index_result) if chunk_index_result else 0

                # # 更新文档长度
                # cursor.execute(
                #     "UPDATE bm_document SET word_num = %s WHERE document_id = %s",
                #     (new_word_num, doc_id)
                # )
                # Insert a new slice; old_slice_text mirrors the new text.
                # NOTE(review): int(slice_index) raises if slice_index is
                # missing from update_json — confirm callers always send it.
                cursor.execute(
                    """INSERT INTO slice_info
                    (slice_id, knowledge_id, document_id, slice_text, create_time, slice_index, old_slice_text)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)""",
                    (chunk_id, knowledge_id, doc_id, chunk_text, date_now, int(slice_index), chunk_text)
                )

            elif operate == "delete":
                # # 更新文档长度和切片总数
                # cursor.execute(
                #     "UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s",
                #     (new_word_num, slice_total, doc_id)
                # )
                # Soft delete: set the flag instead of removing the row (the
                # hard DELETE was deliberately left commented out).
                # cursor.execute(
                #     "DELETE FROM slice_info WHERE slice_id = %s",
                #     (chunk_id,)
                # )
                cursor.execute(
                    "UPDATE slice_info SET del_flag = 1 WHERE slice_id = %s",
                    (chunk_id,)
                )

            # # 更新知识库总长度
            # cursor.execute(
            #     "UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s",
            #     (new_knowledge_word_num, knowledge_id)
            # )

            connection.commit()
            logger.info("bm_document 和 bm_knowledge 数据更新成功")
            return True, "success"

        except Error as e:
            connection.rollback()
            logger.error(f"update_total_doc_len 数据库操作出错:{e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection:
                connection.close()
- def query_slice_info_by_doc_ids(self, knowledge_id: str, doc_ids: list):
- """
- 根据 knowledge_id 和 doc_ids 列表查询 slice_info 表中的数据
-
- 参数:
- knowledge_id: 知识库ID
- doc_ids: 文档ID列表
- 返回:
- (success, data_or_error): 成功时返回数据列表,失败时返回错误信息
- """
- if not doc_ids:
- return True, []
-
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor(dictionary=True)
- placeholders = ','.join(['%s'] * len(doc_ids))
- sql = f"""
- SELECT slice_id, knowledge_id, document_id, slice_text, old_slice_text,
- create_time, update_time, slice_index
- FROM slice_info
- WHERE knowledge_id = %s AND document_id IN ({placeholders})
- """
- cursor.execute(sql, (knowledge_id, *doc_ids))
- results = cursor.fetchall()
- logger.info(f"查询 slice_info 成功,共 {len(results)} 条记录")
- return True, results
- except Error as e:
- logger.error(f"查询 slice_info 出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
-
- def query_media_replacement_by_doc_ids(self, knowledge_id: str, doc_ids: list):
- """
- 根据 knowledge_id 和 doc_ids 列表查询 bm_media_replacement 表中的数据
-
- 参数:
- knowledge_id: 知识库ID
- doc_ids: 文档ID列表
- 返回:
- (success, data_or_error): 成功时返回数据列表,失败时返回错误信息
- """
- if not doc_ids:
- return True, []
-
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor(dictionary=True)
- placeholders = ','.join(['%s'] * len(doc_ids))
- sql = f"""
- SELECT knowledge_id, document_id, origin_text, media_type, media_url, create_time
- FROM bm_media_replacement
- WHERE knowledge_id = %s AND document_id IN ({placeholders})
- """
- cursor.execute(sql, (knowledge_id, *doc_ids))
- results = cursor.fetchall()
- logger.info(f"查询 bm_media_replacement 成功,共 {len(results)} 条记录")
- return True, results
- except Error as e:
- logger.error(f"查询 bm_media_replacement 出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
- def query_bm_document_by_doc_ids(self, knowledge_id: str, doc_ids: list):
- """
- 根据 knowledge_id 和 doc_ids 列表查询 bm_document 表中的数据
-
- 参数:
- knowledge_id: 知识库ID
- doc_ids: 文档ID列表
- 返回:
- (success, data_or_error): 成功时返回数据列表,失败时返回错误信息
- """
- if not doc_ids:
- return True, []
-
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor(dictionary=True)
- placeholders = ','.join(['%s'] * len(doc_ids))
- sql = f"""
- SELECT * FROM bm_document
- WHERE knowledge_id = %s AND document_id IN ({placeholders})
- """
- cursor.execute(sql, (knowledge_id, *doc_ids))
- results = cursor.fetchall()
- logger.info(f"查询 bm_document 成功,共 {len(results)} 条记录")
- return True, results
- except Error as e:
- logger.error(f"查询 bm_document 出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
    def copy_docs_metadata_to_new_knowledge(self, source_knowledge_id: str, source_doc_ids: list, new_knowledge_id: str, doc_id_mapping: dict = None, chunk_id_mapping: dict = None, tenant_id: str = "000000"):
        """
        Copy document metadata (bm_document, slice_info, bm_media_replacement)
        from one knowledge base into another in a single transaction.

        Args:
            source_knowledge_id: knowledge base being copied from.
            source_doc_ids: document ids to copy.
            new_knowledge_id: knowledge base being copied into.
            doc_id_mapping: optional {old_doc_id: new_doc_id}; when omitted a
                fresh snowflake id is generated per source document.
            chunk_id_mapping: optional {old_chunk_id: new_chunk_id} used as
                the new slice_id so rows stay aligned with the vector store.
            tenant_id: tenant identifier supplied by the caller.

        Returns:
            (True, result_dict) with the doc id mapping and per-table counts,
            or (False, {"error": ...}) at the first failure.
        """
        if not source_doc_ids:
            return True, {"message": "无需复制的文档", "doc_id_mapping": {}}

        # Without an explicit mapping, mint a new doc_id for every source doc.
        if doc_id_mapping is None:
            doc_id_mapping = {}
            for old_doc_id in source_doc_ids:
                doc_id_mapping[old_doc_id] = generate_snowflake_id()

        # 1. Load the source slice_info rows.
        success, slice_data = self.query_slice_info_by_doc_ids(source_knowledge_id, source_doc_ids)
        if not success:
            return False, {"error": f"查询切片数据失败: {slice_data}"}

        # 2. Load the source bm_media_replacement rows.
        success, media_data = self.query_media_replacement_by_doc_ids(source_knowledge_id, source_doc_ids)
        if not success:
            return False, {"error": f"查询媒体映射失败: {media_data}"}

        # 3. Load the source bm_document rows.
        success, doc_data = self.query_bm_document_by_doc_ids(source_knowledge_id, source_doc_ids)
        if not success:
            return False, {"error": f"查询文档数据失败: {doc_data}"}

        connection, info = self.get_connection()
        if not connection:
            return False, {"error": info}

        cursor = None
        try:
            cursor = connection.cursor()
            date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

            slice_count = 0
            media_count = 0
            doc_count = 0

            # 4. Insert the copied bm_document rows.
            if doc_data:
                # Full column list; a few columns are overridden per-row below.
                all_fields = [
                    'document_id', 'knowledge_id', 'custom_separator', 'sentence_size', 'length',
                    'word_num', 'slice_total', 'name', 'url', 'parse_image', 'tenant_id',
                    'create_dept', 'create_by', 'create_time', 'update_by', 'update_time',
                    'remark', 'parsing_type', 'oss_id', 'status', 'mark_oss_id', 'mark_url',
                    'ref_document_id', 'qa_checked', 'related_questions_enabled',
                    'summary_generation_enabled', 'parent_generation_enabled', 'suffix', 'pdf_url'
                ]
                doc_values = []
                for row in doc_data:
                    old_doc_id = row['document_id']
                    # NOTE(review): the .get() default mints a fresh id on
                    # every lookup; ids stay consistent across tables only
                    # because the mapping is pre-filled above — confirm.
                    new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id())
                    # Columns that must not be copied verbatim;
                    # ref_document_id back-links to the source document.
                    special = {
                        'document_id': new_doc_id,
                        'knowledge_id': new_knowledge_id,
                        'tenant_id': tenant_id,
                        'create_time': date_now,
                        'ref_document_id': old_doc_id
                    }
                    doc_values.append(tuple(special.get(f, row.get(f)) for f in all_fields))

                placeholders = ','.join(['%s'] * len(all_fields))
                doc_sql = f"INSERT INTO bm_document ({','.join(all_fields)}) VALUES ({placeholders})"
                cursor.executemany(doc_sql, doc_values)
                doc_count = len(doc_values)
                logger.info(f"插入文档数据成功: {doc_count} 条")

            # 5. Insert the copied slice_info rows (old_slice_text keeps the
            #    original text).
            if slice_data:
                slice_values = []

                for row in slice_data:
                    old_doc_id = row['document_id']
                    old_slice_id = row['slice_id']
                    new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id())
                    # Reuse the vector store's chunk_id as slice_id so both
                    # stores stay aligned; fall back to a fresh snowflake id.
                    new_slice_id = chunk_id_mapping.get(old_slice_id, generate_snowflake_id()) if chunk_id_mapping else generate_snowflake_id()

                    # old_slice_text carries the original (pre-edit) text.
                    old_slice_text = row.get('old_slice_text') or row['slice_text']

                    slice_values.append((
                        new_slice_id,
                        new_knowledge_id,
                        new_doc_id,
                        row['slice_text'],
                        old_slice_text,
                        date_now,
                        row.get('slice_index')
                    ))

                slice_sql = """
                INSERT INTO slice_info (
                    slice_id, knowledge_id, document_id, slice_text, old_slice_text, create_time, slice_index
                ) VALUES (%s, %s, %s, %s, %s, %s, %s)
                """

                cursor.executemany(slice_sql, slice_values)
                slice_count = len(slice_values)
                logger.info(f"插入切片数据成功: {slice_count} 条")

            # 6. Insert the copied bm_media_replacement rows.
            if media_data:
                media_values = []
                for row in media_data:
                    old_doc_id = row['document_id']
                    new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id())

                    media_values.append((
                        new_knowledge_id,
                        new_doc_id,
                        row['origin_text'],
                        row['media_type'],
                        row['media_url'],
                        date_now
                    ))

                media_sql = """
                INSERT INTO bm_media_replacement (
                    knowledge_id, document_id, origin_text, media_type, media_url, create_time
                ) VALUES (%s, %s, %s, %s, %s, %s)
                """
                cursor.executemany(media_sql, media_values)
                media_count = len(media_values)
                logger.info(f"插入媒体映射数据成功: {media_count} 条")

            connection.commit()

            return True, {
                "message": "复制元数据成功",
                "doc_id_mapping": doc_id_mapping,
                "doc_count": doc_count,
                "slice_count": slice_count,
                "media_count": media_count
            }

        except Error as e:
            connection.rollback()
            logger.error(f"复制文档元数据失败: {e}")
            return False, {"error": str(e)}
        finally:
            if cursor:
                cursor.close()
            if connection:
                connection.close()
- def query_slice_by_id(self, knowledge_id: str, slice_id: str):
- """
- 根据 knowledge_id 和 slice_id 查询单个切片数据
- """
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor(dictionary=True)
- sql = """
- SELECT slice_id, knowledge_id, document_id, slice_text, old_slice_text, slice_index
- FROM slice_info WHERE knowledge_id = %s AND slice_id = %s
- """
- cursor.execute(sql, (knowledge_id, slice_id))
- results = cursor.fetchall()
- return True, results
- except Error as e:
- logger.error(f"查询 slice_info 出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
- def query_slice_by_knowledge_and_doc(self, knowledge_id: str, document_id: str):
- """
- 根据 knowledge_id 和 document_id 查询 slice_info 表中的切片数据
- """
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- cursor = connection.cursor(dictionary=True)
- sql = """
- SELECT slice_id, knowledge_id, document_id, slice_text, old_slice_text, slice_index
- FROM slice_info
- WHERE knowledge_id = %s AND document_id = %s
- """
- cursor.execute(sql, (knowledge_id, document_id))
- results = cursor.fetchall()
- logger.info(f"查询 slice_info 成功,共 {len(results)} 条记录")
- return True, results
- except Error as e:
- logger.error(f"查询 slice_info 出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
- def update_slice_llm_fields(self, knowledge_id: str, slice_id: str, qa: str = None, question: str = None, summary: str = None):
- """
- 更新 slice_info 表中的 qa、question、summary 字段,值为 None 的字段不更新
- 基于 knowledge_id 和 slice_id 的包含关系进行更新
- """
- connection, info = self.get_connection()
- if not connection:
- return False, info
- cursor = None
- try:
- # 动态构建 SET 子句,只更新非 None 字段
- set_parts = []
- params = []
- if qa is not None:
- set_parts.append("qa = %s")
- params.append(qa)
- if question is not None:
- set_parts.append("question = %s")
- params.append(question)
- if summary is not None:
- set_parts.append("summary = %s")
- params.append(summary)
-
- if not set_parts:
- return True, "no fields to update"
-
- date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- set_parts.append("update_time = %s")
- params.append(date_now)
- params.append(knowledge_id)
- params.append(slice_id)
-
- cursor = connection.cursor()
- sql = f"UPDATE slice_info SET {', '.join(set_parts)} WHERE knowledge_id = %s AND slice_id = %s"
- cursor.execute(sql, tuple(params))
- connection.commit()
- logger.info(f"更新切片 {slice_id} 的LLM字段成功")
- return True, "success"
- except Error as e:
- connection.rollback()
- logger.error(f"更新 slice_info LLM字段出错: {e}")
- return False, str(e)
- finally:
- if cursor:
- cursor.close()
- if connection:
- connection.close()
|