db.py

from rag.vector_db.milvus_vector import HybridRetriever
from response_info import generate_message, generate_response
from utils.get_logger import setup_logger
from datetime import datetime
from uuid import uuid1
from mysql.connector import pooling, Error
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from config import milvus_uri, mysql_config

logger = setup_logger(__name__)

# uri = "http://localhost:19530"

try:
    POOL = pooling.MySQLConnectionPool(
        pool_name="mysql_pool",
        pool_size=10,
        **mysql_config
    )
    logger.info("MySQL connection pool initialized successfully")
except Error as e:
    logger.error(f"Failed to initialize MySQL connection pool: {e}")
    POOL = None
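
# For reference: `mysql_config` (defined in config.py, not shown here) is
# unpacked straight into MySQLConnectionPool, so it is presumably a plain
# kwargs dict of mysql.connector connection arguments. The exact keys are an
# assumption, e.g.:
#   mysql_config = {
#       "host": "127.0.0.1",
#       "port": 3306,
#       "user": "rag",
#       "password": "secret",
#       "database": "rag",
#   }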

class MilvusOperate:
    def __init__(self, collection_name: str = "default", embedding_name: str = "e5"):
        self.collection = collection_name
        self.hybrid_retriever = HybridRetriever(uri=milvus_uri, embedding_name=embedding_name, collection_name=collection_name)
        self.mysql_client = MysqlOperate()

    def _has_collection(self):
        return self.hybrid_retriever.has_collection()

    def _create_collection(self):
        if self._has_collection():
            resp = {"code": 400, "message": "collection already exists"}
        else:
            create_result = self.hybrid_retriever.build_collection()
            resp = generate_message(create_result)
        return resp

    def _delete_collection(self):
        delete_result = self.hybrid_retriever.delete_collection(self.collection)
        resp = generate_message(delete_result)
        return resp

    def _put_by_id(self, slice_json):
        slice_id = slice_json.get("slice_id", None)
        slice_text = slice_json.get("slice_text", None)
        update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text)
        if update_result.endswith("success"):
            # On success, also update the knowledge-base total length and the
            # document length tracked in MySQL
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "update",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            update_flag = False
        if not update_flag:
            update_result = "update_error"
        resp = generate_message(update_result)
        return resp
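
    # Example payload for _put_by_id (keys as read above; values are
    # illustrative only):
    #   {
    #       "slice_id": "existing-chunk-uuid",
    #       "slice_text": "replacement text for the chunk",
    #       "knowledge_id": "kb-1",
    #       "document_id": "doc-1",
    #   }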

    def _insert_slice(self, slice_json):
        slice_id = str(uuid1())
        knowledge_id = slice_json.get("knowledge_id")
        doc_id = slice_json.get("document_id")
        slice_text = slice_json.get("slice_text", None)
        doc_name = slice_json.get("doc_name")
        chunk_len = len(slice_text)
        metadata = {
            "content": slice_text,
            "doc_id": doc_id,
            "chunk_id": slice_id,
            "metadata": {"source": doc_name, "chunk_len": chunk_len}
        }
        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
        if insert_flag:
            # On success, also update the knowledge-base total length and the
            # document length tracked in MySQL
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": doc_id,
                "chunk_len": chunk_len,
                "operate": "insert",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error(f"Failed to insert into the vector store: {insert_str}")
            update_flag = False
            update_str = "vector store write failed"
        if not update_flag:
            logger.error(f"MySQL error while inserting slice: {update_str}")
            insert_result = "insert_error"
        else:
            insert_result = "insert_success"
        resp = generate_message(insert_result)
        return resp
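
    # Example payload for _insert_slice (no slice_id here: a fresh uuid1 is
    # generated above; values are illustrative only):
    #   {
    #       "knowledge_id": "kb-1",
    #       "document_id": "doc-1",
    #       "slice_text": "a new chunk of text",
    #       "doc_name": "example.txt",
    #   }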

    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
        logger.info(f"Deleting slice id: {chunk_id}")
        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id)
        if delete_result.endswith("success"):
            chunk_len = delete_chunk_len[0]
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": document_id,
                "chunk_len": -chunk_len,  # negative: shrink the tracked lengths
                "operate": "delete",
                "chunk_id": chunk_id
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error("Failed to delete from the vector store by chunk id")
            update_flag = False
            update_str = "delete by chunk id failed"
        if not update_flag:
            logger.error(update_str)
            delete_result = "delete_error"
        resp = generate_message(delete_result)
        return resp

    def _delete_by_doc_id(self, doc_id: str = None):
        logger.info(f"Deleting data for doc id: {doc_id}")
        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
        resp = generate_message(delete_result)
        return resp

    def _search_by_chunk_id(self, chunk_id):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
        else:
            query_result = []
        logger.info(f"Slice lookup result: {query_result}")
        resp = generate_response(query_result)
        return resp

    def _search_by_key_word(self, search_json):
        # Read pagination up front so both branches can paginate the response
        page_num = search_json.get("pageNum", 1)
        page_size = search_json.get("pageSize", 10)
        if self._has_collection():
            doc_id = search_json.get("document_id", None)
            text = search_json.get("text", None)
            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
        else:
            query_result = []
        resp = generate_response(query_result, page_num, page_size)
        return resp
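
    # Example search payload (keys as read above; values illustrative):
    #   {"document_id": "doc-1", "text": "keyword to match", "pageNum": 1, "pageSize": 10}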

    def _insert_data(self, docs):
        insert_flag, insert_info = True, "success"  # defaults for an empty docs list
        for doc in docs:
            chunk = doc.get("content")
            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
            if not insert_flag:
                # Stop at the first failed insert
                break
        return insert_flag, insert_info

    def _batch_insert_data(self, docs, text_lists):
        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
        return insert_flag, insert_info

    def _search(self, query, k, mode):
        search_result = self.hybrid_retriever.search(query, k, mode)
        return search_result
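
    # A minimal search sketch. The accepted values of `mode` are defined by
    # HybridRetriever.search (not shown here); "hybrid" is an assumption:
    #   op = MilvusOperate(collection_name="demo_kb")
    #   hits = op._search("what is retrieval-augmented generation", k=5, mode="hybrid")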

class MysqlOperate:
    def get_connection(self):
        """
        Fetch a connection from the pool.
        :return: (connection, info); connection is None on failure
        """
        if POOL is None:
            logger.error("MySQL connection pool is not initialized")
            return None, "MySQL connection pool is not initialized"
        try:
            # Run the blocking pool call in a worker thread so future.result()
            # can bound the wait with a timeout
            with ThreadPoolExecutor() as executor:
                future = executor.submit(POOL.get_connection)
                connection = future.result(timeout=5.0)  # 5-second timeout
            logger.info("Obtained a connection from the pool")
            return connection, "success"
        except TimeoutError:
            logger.error("Timed out getting a MySQL connection from the pool")
            return None, "timed out getting a MySQL connection"
        except Error as e:
            logger.error(f"Could not get a connection from the pool: {e}")
            return None, str(e)

    def insert_to_slice(self, docs, knowledge_id, doc_id):
        """
        Insert slice rows into the slice_info table
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = []
        for chunk in docs:
            slice_id = chunk.get("chunk_id")
            slice_text = chunk.get("content")
            chunk_index = chunk.get("metadata").get("chunk_index")
            values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index))
        cursor = None
        try:
            cursor = connection.cursor()
            insert_sql = """
                INSERT INTO slice_info (
                    slice_id,
                    knowledge_id,
                    document_id,
                    slice_text,
                    create_time,
                    slice_index
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """
            cursor.executemany(insert_sql, values)
            connection.commit()
            logger.info("Batch insert of slice data succeeded")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
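
    # Each element of `docs` must carry at least these keys (shape inferred
    # from the lookups above; values illustrative):
    #   {"chunk_id": "uuid", "content": "chunk text", "metadata": {"chunk_index": 0}}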

    def delete_to_slice(self, doc_id):
        """
        Delete a document's slices from slice_info
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            delete_sql = "DELETE FROM slice_info WHERE document_id = %s"
            cursor.execute(delete_sql, (doc_id,))
            connection.commit()
            logger.info("Slice data deleted successfully")
            return True, "success"
        except Error as e:
            logger.error(f"Failed to delete slices for document {doc_id}: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()

    def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
        """
        Batch-insert image replacement records into bm_media_replacement
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = []
        for origin_text, media_url in image_dict.items():
            values.append((knowledge_id, doc_id, origin_text, "image", media_url, date_now))
        cursor = None
        try:
            cursor = connection.cursor()
            insert_sql = """
                INSERT INTO bm_media_replacement (
                    knowledge_id,
                    document_id,
                    origin_text,
                    media_type,
                    media_url,
                    create_time
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """
            cursor.executemany(insert_sql, values)
            connection.commit()
            logger.info("Insert into bm_media_replacement succeeded")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
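
    # `image_dict` maps the placeholder text found in the document to the
    # uploaded image URL, e.g. (illustrative values):
    #   {"<image_0>": "https://cdn.example.com/kb-1/doc-1/image_0.png"}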

    def delete_image_url(self, doc_id):
        """
        Delete rows from bm_media_replacement by doc id
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            delete_sql = "DELETE FROM bm_media_replacement WHERE document_id = %s"
            cursor.execute(delete_sql, (doc_id,))
            connection.commit()
            logger.info(f"Deleted bm_media_replacement rows for document {doc_id}")
            return True, "success"
        except Error as e:
            logger.error(f"Failed to delete bm_media_replacement rows for {doc_id}: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
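
    # `update_json` shapes accepted by update_total_doc_len below ("operate"
    # selects the branch; values illustrative):
    #   update: {"knowledge_id", "doc_id", "chunk_len", "operate": "update",
    #            "chunk_id", "chunk_text"}
    #   insert: same keys with "operate": "insert"; chunk_len is the added length
    #   delete: {"knowledge_id", "doc_id", "chunk_len": -removed_len,
    #            "operate": "delete", "chunk_id"}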

    def update_total_doc_len(self, update_json):
        """
        Update the knowledge-base and document word counts, and keep slice_info
        in sync: update, insert, or delete the slice row depending on "operate".
        """
        knowledge_id = update_json.get("knowledge_id")
        doc_id = update_json.get("doc_id")
        chunk_len = update_json.get("chunk_len")
        operate = update_json.get("operate")
        chunk_id = update_json.get("chunk_id")
        chunk_text = update_json.get("chunk_text")
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            query_doc_word_num_sql = "SELECT word_num, slice_total FROM bm_document WHERE document_id = %s"
            query_knowledge_word_num_sql = "SELECT word_num FROM bm_knowledge WHERE knowledge_id = %s"
            cursor.execute(query_doc_word_num_sql, (doc_id,))
            doc_result = cursor.fetchone()
            logger.info(f"Document length row: {doc_result}")
            cursor.execute(query_knowledge_word_num_sql, (knowledge_id,))
            knowledge_result = cursor.fetchone()
            logger.info(f"Knowledge-base total length row: {knowledge_result}")
            if not doc_result:
                new_word_num = 0
                slice_total = 0
            else:
                old_word_num = doc_result[0]
                slice_total = doc_result[1]
                new_word_num = old_word_num + chunk_len
                slice_total -= 1 if slice_total else 0  # only written back by the delete branch
            if not knowledge_result:
                new_knowledge_word_num = 0
            else:
                old_knowledge_word_num = knowledge_result[0]
                new_knowledge_word_num = old_knowledge_word_num + chunk_len
            if operate == "update":
                update_sql = "UPDATE bm_document SET word_num = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, doc_id))
                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                update_slice_sql = "UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s"
                cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id))
            elif operate == "insert":
                # The new slice goes after the document's current highest slice_index
                query_slice_info_index_sql = "SELECT MAX(slice_index) FROM slice_info WHERE document_id = %s"
                cursor.execute(query_slice_info_index_sql, (doc_id,))
                chunk_index_result = cursor.fetchone()[0]
                chunk_max_index = int(chunk_index_result) if chunk_index_result else 0
                update_sql = "UPDATE bm_document SET word_num = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, doc_id))
                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                insert_slice_sql = "INSERT INTO slice_info (slice_id, knowledge_id, document_id, slice_text, create_time, slice_index) VALUES (%s, %s, %s, %s, %s, %s)"
                cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index + 1))
            else:
                # delete: chunk_len was passed in as a negative value
                update_sql = "UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, slice_total, doc_id))
                # Remove the slice row for this chunk id
                delete_slice_sql = "DELETE FROM slice_info WHERE slice_id = %s"
                cursor.execute(delete_slice_sql, (chunk_id,))
            update_knowledge_sql = "UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s"
            cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id))
            connection.commit()
            logger.info("bm_document and bm_knowledge updated successfully")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
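
if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes a reachable Milvus at milvus_uri and a
    # MySQL instance matching mysql_config; collection name and ids are
    # illustrative only.
    op = MilvusOperate(collection_name="demo_kb")
    print(op._create_collection())
    print(op._insert_slice({
        "knowledge_id": "kb-1",
        "document_id": "doc-1",
        "slice_text": "An example chunk of text.",
        "doc_name": "example.txt",
    }))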