# db.py
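"""
Milvus vector-store operations (MilvusOperate) plus the MySQL bookkeeping
(SafeMySQLPool / MysqlOperate) that keeps slice, document and knowledge-base
metadata in sync with the vector store.
"""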
from datetime import datetime
from uuid import uuid1
import threading
import time

from mysql.connector import pooling, Error, errors

from config import milvus_uri, mysql_config
from rag.vector_db.milvus_vector import HybridRetriever
from response_info import generate_message, generate_response
from utils.get_logger import setup_logger

logger = setup_logger(__name__)

class MilvusOperate:
    def __init__(self, collection_name: str = "default", embedding_name: str = "e5"):
        self.collection = collection_name
        self.hybrid_retriever = HybridRetriever(
            uri=milvus_uri, embedding_name=embedding_name, collection_name=collection_name
        )
        self.mysql_client = MysqlOperate()

    def _has_collection(self):
        return self.hybrid_retriever.has_collection()

    def _create_collection(self):
        if self._has_collection():
            resp = {"code": 400, "message": "collection already exists"}
        else:
            create_result = self.hybrid_retriever.build_collection()
            resp = generate_message(create_result)
        return resp

    def _delete_collection(self):
        delete_result = self.hybrid_retriever.delete_collection(self.collection)
        return generate_message(delete_result)

    def _put_by_id(self, slice_json):
        slice_id = slice_json.get("slice_id", None)
        slice_text = slice_json.get("slice_text", None)
        update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text)
        if update_result.endswith("success"):
            # On success, update the knowledge-base and document word counts in MySQL.
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "update",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
            if not update_flag:
                logger.error(f"MySQL error while updating slice: {update_str}")
        else:
            update_flag = False
        if not update_flag:
            update_result = "update_error"
        return generate_message(update_result)

    def _insert_slice(self, slice_json):
        slice_id = str(uuid1())
        knowledge_id = slice_json.get("knowledge_id")
        doc_id = slice_json.get("document_id")
        slice_text = slice_json.get("slice_text", None)
        doc_name = slice_json.get("doc_name")
        chunk_len = len(slice_text)
        metadata = {
            "content": slice_text,
            "doc_id": doc_id,
            "chunk_id": slice_id,
            "metadata": {"source": doc_name, "chunk_len": chunk_len},
        }
        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
        if insert_flag:
            # On success, update the knowledge-base and document word counts in MySQL.
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": doc_id,
                "chunk_len": chunk_len,
                "operate": "insert",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error(f"failed to insert into vector store: {insert_str}")
            update_flag = False
            update_str = "vector store write failed"
        if not update_flag:
            logger.error(f"MySQL error while inserting slice: {update_str}")
            insert_result = "insert_error"
        else:
            insert_result = "insert_success"
        return generate_message(insert_result)

    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
        logger.info(f"deleting slice id: {chunk_id}")
        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id)
        if delete_result.endswith("success"):
            chunk_len = delete_chunk_len[0]
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": document_id,
                "chunk_len": -chunk_len,  # negative delta: shrink the word counts
                "operate": "delete",
                "chunk_id": chunk_id,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error("failed to delete from vector store by chunk id")
            update_flag = False
            update_str = "delete by chunk id failed"
        if not update_flag:
            logger.error(update_str)
            delete_result = "delete_error"
        return generate_message(delete_result)

    def _delete_by_doc_id(self, doc_id: str = None):
        logger.info(f"deleting data for doc id: {doc_id}")
        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
        return generate_message(delete_result)

    def _search_by_chunk_id(self, chunk_id):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
        else:
            query_result = []
        logger.info(f"slice lookup result: {query_result}")
        return generate_response(query_result)

    def _search_by_chunk_id_list(self, chunk_id_list):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list)
        else:
            query_result = []
        logger.info(f"slice info for recalled chunk list: {query_result}")
        return [chunk_dict.get("content") for chunk_dict in query_result]

    def _search_by_key_word(self, search_json):
        if self._has_collection():
            doc_id = search_json.get("document_id", None)
            text = search_json.get("text", None)
            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
        else:
            query_result = []
        # Paging defaults apply on both branches (previously undefined when the
        # collection was missing).
        page_num = search_json.get("pageNum", 1)
        page_size = search_json.get("pageSize", 10)
        return generate_response(query_result, page_num, page_size)

    def _insert_data(self, docs):
        # Defaults cover the empty-docs case (previously a NameError).
        insert_flag, insert_info = True, "success"
        for doc in docs:
            chunk = doc.get("content")
            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
            if not insert_flag:
                break
        return insert_flag, insert_info

    def _batch_insert_data(self, docs, text_lists):
        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
        return insert_flag, insert_info

    def _search(self, query, k, mode):
        return self.hybrid_retriever.search(query, k, mode)
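
# Usage sketch (hypothetical collection/embedding names; the search mode value
# depends on what HybridRetriever.search supports in this deployment):
#   milvus = MilvusOperate(collection_name="kb_demo", embedding_name="e5")
#   milvus._create_collection()
#   results = milvus._search("query text", k=5, mode="hybrid")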

# ========== Global MySQL connection pool (auto-reclaim + timeout protection) ==========
class SafeMySQLPool:
    def __init__(self, pool_size=10, conn_timeout=10, idle_timeout=60, **mysql_config):
        mysql_config.setdefault("connect_timeout", conn_timeout)
        mysql_config.setdefault("pool_reset_session", True)
        self._pool = pooling.MySQLConnectionPool(
            pool_name="safe_mysql_pool",
            pool_size=pool_size,
            **mysql_config
        )
        self._lock = threading.Lock()
        self._active_conns = {}  # {id(conn): (conn, last_used_time)}
        self._idle_timeout = idle_timeout
        self._stop_event = threading.Event()
        threading.Thread(target=self._auto_reclaimer, daemon=True).start()

    def get_connection(self, timeout=10):
        """Acquire a connection, with timeout protection and usage tracking."""
        start = time.time()
        while True:
            try:
                conn = self._pool.get_connection()
                with self._lock:
                    # last_used is set at acquisition only; work holding the
                    # connection should finish within idle_timeout.
                    self._active_conns[id(conn)] = (conn, time.time())
                return self._wrap_connection(conn)
            except errors.PoolError:
                # Pool exhausted: retry until the overall timeout elapses.
                if time.time() - start > timeout:
                    raise TimeoutError(f"timed out acquiring MySQL connection (> {timeout}s)")
                time.sleep(0.3)

    def _wrap_connection(self, conn):
        """Wrap the connection so close() also drops it from the tracker."""
        pool = self
        orig_close = conn.close

        def safe_close():
            try:
                orig_close()
            finally:
                with pool._lock:
                    pool._active_conns.pop(id(conn), None)

        conn.close = safe_close
        return conn

    def _auto_reclaimer(self):
        """Background thread: reclaim connections that were never closed in time."""
        while not self._stop_event.is_set():
            time.sleep(5)
            now = time.time()
            with self._lock:
                stale = [
                    (cid, conn, last_used)
                    for cid, (conn, last_used) in self._active_conns.items()
                    if now - last_used > self._idle_timeout
                ]
            # Close outside the lock: the wrapped close() re-acquires it, so
            # closing while holding the lock would deadlock.
            for cid, conn, last_used in stale:
                try:
                    conn.close()
                    logger.warning(f"[AutoReclaim] reclaimed stale connection (idle={int(now - last_used)}s)")
                except Exception as e:
                    logger.error(f"[AutoReclaim] failed to reclaim connection: {e}")
                with self._lock:
                    self._active_conns.pop(cid, None)

    def close_all(self):
        """Stop the reclaimer thread and close every tracked connection."""
        self._stop_event.set()
        with self._lock:
            conns = [conn for conn, _ in self._active_conns.values()]
            self._active_conns.clear()
        # Same deadlock caveat as above: close after releasing the lock.
        for conn in conns:
            try:
                conn.close()
            except Exception:
                pass
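
# Usage sketch (assumes mysql_config from config.py is valid):
#   pool = SafeMySQLPool(pool_size=5, idle_timeout=60, **mysql_config)
#   conn = pool.get_connection(timeout=5)
#   try:
#       cursor = conn.cursor()
#       cursor.execute("SELECT 1")
#       cursor.fetchall()
#   finally:
#       conn.close()  # the wrapped close() also removes the conn from the tracker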
  525. # ========== 初始化连接池 ==========
  526. if "POOL" not in globals():
  527. try:
  528. POOL = SafeMySQLPool(pool_size=10, idle_timeout=60, **mysql_config)
  529. logger.info("MySQL 连接池初始化成功")
  530. except Error as e:
  531. logger.error(f"MySQL 连接池初始化失败: {e}")
  532. POOL = None

# ========== MysqlOperate ==========
class MysqlOperate:
    def get_connection(self):
        """Safely acquire a pooled connection."""
        if not POOL:
            return None, "connection pool not initialized"
        try:
            connection = POOL.get_connection(timeout=5)
            return connection, "success"
        except TimeoutError as e:
            logger.error(str(e))
            return None, "timed out acquiring connection"
        except Error as e:
            logger.error(f"failed to acquire MySQL connection: {e}")
            return None, str(e)

    def _execute_many(self, sql, values, success_msg, err_msg):
        """Generic batch-execute helper: run executemany inside a transaction."""
        connection, info = self.get_connection()
        if not connection:
            return False, info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.executemany(sql, values)
            connection.commit()
            logger.info(success_msg)
            return True, "success"
        except Error as e:
            connection.rollback()
            logger.error(f"{err_msg}: {e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            connection.close()

    def insert_to_slice(self, docs, knowledge_id, doc_id):
        """Batch-insert slice rows into slice_info (upsert on duplicate key)."""
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = [
            (
                chunk.get("chunk_id"),
                knowledge_id,
                doc_id,
                chunk.get("content"),
                date_now,
                chunk.get("metadata", {}).get("chunk_index")
            )
            for chunk in docs
        ]
        sql = """
            INSERT INTO slice_info (
                slice_id, knowledge_id, document_id, slice_text, create_time, slice_index
            ) VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                slice_text = VALUES(slice_text),
                create_time = VALUES(create_time),
                slice_index = VALUES(slice_index)
        """
        return self._execute_many(sql, values, "batch insert into slice_info succeeded", "error inserting into slice_info")

    def delete_to_slice(self, doc_id):
        """Delete all slices of a document from slice_info."""
        connection, info = self.get_connection()
        if not connection:
            return False, info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.execute("DELETE FROM slice_info WHERE document_id = %s", (doc_id,))
            connection.commit()
            logger.info("deleted slice_info rows")
            return True, "success"
        except Error as e:
            connection.rollback()
            logger.error(f"error deleting from slice_info: {e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            connection.close()

    def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
        """Batch-insert image replacements into bm_media_replacement (upsert)."""
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = [
            (knowledge_id, doc_id, origin_text, "image", media_url, date_now)
            for origin_text, media_url in image_dict.items()
        ]
        sql = """
            INSERT INTO bm_media_replacement (
                knowledge_id, document_id, origin_text, media_type, media_url, create_time
            ) VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                origin_text = VALUES(origin_text),
                media_type = VALUES(media_type),
                media_url = VALUES(media_url),
                create_time = VALUES(create_time)
        """
        return self._execute_many(sql, values, "insert into bm_media_replacement succeeded", "error inserting into bm_media_replacement")

    def delete_image_url(self, doc_id):
        """Delete image replacements for a document from bm_media_replacement."""
        connection, info = self.get_connection()
        if not connection:
            return False, info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.execute("DELETE FROM bm_media_replacement WHERE document_id = %s", (doc_id,))
            connection.commit()
            logger.info("deleted bm_media_replacement rows")
            return True, "success"
        except Error as e:
            connection.rollback()
            logger.error(f"error deleting from bm_media_replacement: {e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            connection.close()

    def delete_document(self, task_id):
        """Delete a bm_document row (cleanup before cancelling a task).

        Note: on success this returns (True, affected_rows), unlike the other
        methods, which return (True, "success").
        """
        connection, info = self.get_connection()
        if not connection:
            return False, info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.execute("DELETE FROM bm_document WHERE document_id = %s", (task_id,))
            affected_rows = cursor.rowcount
            connection.commit()
            logger.info(f"deleted bm_document row: task_id={task_id}, affected rows={affected_rows}")
            return True, affected_rows
        except Error as e:
            connection.rollback()
            logger.error(f"failed to delete bm_document row: {e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection:
                connection.close()
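
    def update_total_doc_len(self, update_json):
        """Keep word counts and slice_info in sync after a slice operation.

        Restored from the commented-out legacy implementation of this class:
        MilvusOperate._put_by_id / _insert_slice / _delete_by_chunk_id all call
        this method, so the refactored class needs it too. Behavior follows the
        legacy code; the delete path assumes callers pass a negative chunk_len.
        """
        knowledge_id = update_json.get("knowledge_id")
        doc_id = update_json.get("doc_id")
        chunk_len = update_json.get("chunk_len")
        operate = update_json.get("operate")
        chunk_id = update_json.get("chunk_id")
        chunk_text = update_json.get("chunk_text")
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.execute("SELECT word_num, slice_total FROM bm_document WHERE document_id = %s", (doc_id,))
            doc_result = cursor.fetchone()
            logger.info(f"document word count row: {doc_result}")
            cursor.execute("SELECT word_num FROM bm_knowledge WHERE knowledge_id = %s", (knowledge_id,))
            knowledge_result = cursor.fetchone()
            logger.info(f"knowledge-base word count row: {knowledge_result}")
            if not doc_result:
                new_word_num = 0
                slice_total = 0
            else:
                new_word_num = doc_result[0] + chunk_len
                slice_total = doc_result[1]
                # Decrement the slice count; only the delete branch writes it back.
                slice_total -= 1 if slice_total else 0
            if not knowledge_result:
                new_knowledge_word_num = 0
            else:
                new_knowledge_word_num = knowledge_result[0] + chunk_len
            date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            if operate == "update":
                cursor.execute("UPDATE bm_document SET word_num = %s WHERE document_id = %s", (new_word_num, doc_id))
                cursor.execute(
                    "UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s",
                    (chunk_text, date_now, chunk_id),
                )
            elif operate == "insert":
                # New slice goes after the current highest index for the document.
                cursor.execute("SELECT MAX(slice_index) FROM slice_info WHERE document_id = %s", (doc_id,))
                chunk_index_result = cursor.fetchone()[0]
                chunk_max_index = int(chunk_index_result) if chunk_index_result else 0
                cursor.execute("UPDATE bm_document SET word_num = %s WHERE document_id = %s", (new_word_num, doc_id))
                cursor.execute(
                    "INSERT INTO slice_info (slice_id, knowledge_id, document_id, slice_text, create_time, slice_index)"
                    " VALUES (%s, %s, %s, %s, %s, %s)",
                    (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index + 1),
                )
            else:  # delete
                cursor.execute(
                    "UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s",
                    (new_word_num, slice_total, doc_id),
                )
                cursor.execute("DELETE FROM slice_info WHERE slice_id = %s", (chunk_id,))
            cursor.execute(
                "UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s",
                (new_knowledge_word_num, knowledge_id),
            )
            connection.commit()
            logger.info("bm_document and bm_knowledge updated")
            return True, "success"
        except Error as e:
            connection.rollback()
            logger.error(f"database error: {e}")
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            connection.close()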