Просмотр исходного кода

修改mysql连接池获取和删除向量库

weiyu 6 месяцев назад
Родитель
Сommit
1b5a62fb5a
2 измененных файлов с 43 добавлено и 17 удалено
  1. 42 12
      rag/db.py
  2. 1 5
      rag/documents_process.py

+ 42 - 12
rag/db.py

@@ -5,6 +5,7 @@ from datetime import datetime
 from uuid import uuid1
 import mysql.connector
 from mysql.connector import pooling, Error
+import threading
 from concurrent.futures import ThreadPoolExecutor, TimeoutError
 from config import milvus_uri, mysql_config
 
@@ -197,19 +198,48 @@ class MysqlOperate:
         从连接池中获取一个连接
         :return: 数据库连接对象
         """
-        try:
-            with ThreadPoolExecutor() as executor:
-                future = executor.submit(POOL.get_connection)
-                connection = future.result(timeout=5.0)  # 设置超时时间为5秒
-
-                logger.info("成功从连接池获取连接")
+        # try:
+            # with ThreadPoolExecutor() as executor:
+            #     future = executor.submit(POOL.get_connection)
+            #     connection = future.result(timeout=5.0)  # 设置超时时间为5秒
+            # logger.info("成功从连接池获取连接")
+            # return connection, "success"
+        # except TimeoutError:
+        #     logger.error("获取mysql数据库连接池超时")
+        #     return None, "mysql获取连接池超时"
+        # except errors.InterfaceError as e:
+        #     logger.error(f"MySQL 接口异常:{e}")
+        #     return None, "mysql接口异常"
+        # except errors.OperationalError as e:
+        #     logger.error(f"MySQL 操作错误:{e}")
+        #     return None, "mysql 操作错误"
+        # except Error as e:
+        #     logger.error(f"无法从连接池获取连接: {e}")
+        #     return None, str(e)
+        connection = None
+        event = threading.Event()
+
+        def target():
+            nonlocal connection
+            try:
+                connection = POOL.get_connection()
+            finally:
+                event.set()
+
+        thread = threading.Thread(target=target)
+        thread.start()
+        event.wait(timeout=5)
+
+        if thread.is_alive():
+            # 超时处理
+            logger.error("获取连接超时")
+            return None, "获取连接超时"
+        else:
+            if connection:
                 return connection, "success"
-        except TimeoutError:
-            logger.error("获取mysql数据库连接池超时")
-            return None, "mysql获取连接池超时"
-        except Error as e:
-            logger.error(f"无法从连接池获取连接: {e}")
-            return None, str(e)
+            else:
+                logger.error("获取连接失败")
+                return None, "获取连接失败"
 
     def insert_to_slice(self, docs, knowledge_id, doc_id):
         """

+ 1 - 5
rag/documents_process.py

@@ -324,8 +324,6 @@ class ProcessDocuments():
                     # 插入到mysql的slice info数据库中
                     insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, doc_id)
                 else:
-                    # resp = {"code": 500, "message": insert_milvus_str}
-                    # return resp
                     insert_slice_flag = False
                     parse_file_status = False
 
@@ -333,15 +331,13 @@ class ProcessDocuments():
                     # 插入mysql中的bm_media_replacement表中
                     insert_img_flag, insert_mysql_info =  self.mysql_client.insert_to_image_url(flag_img_info, self.knowledge_id, doc_id)
                 else:
-                    # resp = {"code": 500, "message": insert_mysql_info}
-                    self.milvus_client._delete_by_doc_id(doc_id=doc_id)
+                    # self.milvus_client._delete_by_doc_id(doc_id=doc_id)
                     insert_img_flag = False
 
                     # return resp
                     parse_file_status = False
 
                 if insert_img_flag:
-                    # resp = {"code": 200, "message": "文档解析成功"}
                     parse_file_status = True
                 
                 else: