from rag.vector_db.milvus_vector import HybridRetriever
from response_info import generate_message, generate_response
from utils.get_logger import setup_logger
from datetime import datetime
from uuid import uuid1
from mysql.connector import pooling, Error
import threading
from config import milvus_uri, mysql_config

logger = setup_logger(__name__)

# uri = "http://localhost:19530"
try:
    POOL = pooling.MySQLConnectionPool(
        pool_name="mysql_pool",
        pool_size=10,
        **mysql_config
    )
    logger.info("MySQL connection pool initialized successfully")
except Error as e:
    logger.error(f"Failed to initialize MySQL connection pool: {e}")
    POOL = None
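
# Pooled-connection usage sketch (standard mysql-connector-python pool API):
#   conn = POOL.get_connection()   # borrow a connection from the pool
#   cur = conn.cursor()
#   cur.execute("SELECT 1")
#   cur.close()
#   conn.close()                   # hands the connection back to the pool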

class MilvusOperate:
    def __init__(self, collection_name: str = "default", embedding_name: str = "e5"):
        self.collection = collection_name
        self.hybrid_retriever = HybridRetriever(
            uri=milvus_uri,
            embedding_name=embedding_name,
            collection_name=collection_name,
        )
        self.mysql_client = MysqlOperate()

    def _has_collection(self):
        return self.hybrid_retriever.has_collection()

    def _create_collection(self):
        if self._has_collection():
            resp = {"code": 400, "message": "Collection already exists"}
        else:
            create_result = self.hybrid_retriever.build_collection()
            resp = generate_message(create_result)
        return resp

    def _delete_collection(self):
        delete_result = self.hybrid_retriever.delete_collection(self.collection)
        resp = generate_message(delete_result)
        return resp

    def _put_by_id(self, slice_json):
        slice_id = slice_json.get("slice_id", None)
        slice_text = slice_json.get("slice_text", None)
        update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text)
        if update_result.endswith("success"):
            # On success, update the knowledge-base total length and the
            # document length kept in MySQL.
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "update",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            update_flag = False

        if not update_flag:
            update_result = "update_error"

        resp = generate_message(update_result)
        return resp

    def _insert_slice(self, slice_json):
        slice_id = str(uuid1())
        knowledge_id = slice_json.get("knowledge_id")
        doc_id = slice_json.get("document_id")
        slice_text = slice_json.get("slice_text", None)
        doc_name = slice_json.get("doc_name")
        chunk_len = len(slice_text)
        metadata = {
            "content": slice_text,
            "doc_id": doc_id,
            "chunk_id": slice_id,
            "metadata": {"source": doc_name, "chunk_len": chunk_len}
        }
        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
        if insert_flag:
            # On success, update the knowledge-base total length and the
            # document length kept in MySQL.
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": doc_id,
                "chunk_len": chunk_len,
                "operate": "insert",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error(f"Failed to insert into the vector store: {insert_str}")
            update_flag = False
            update_str = "vector store write failed"

        if not update_flag:
            logger.error(f"MySQL error while inserting a slice: {update_str}")
            insert_result = "insert_error"
        else:
            insert_result = "insert_success"

        resp = generate_message(insert_result)
        return resp
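
    # Example payload for _insert_slice (shape inferred from the lookups
    # above; ids and file name are hypothetical):
    # {
    #     "knowledge_id": "kb-001",
    #     "document_id": "doc-001",
    #     "doc_name": "manual.txt",
    #     "slice_text": "body of the new chunk ...",
    # }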

    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
        logger.info(f"Deleting slice id: {chunk_id}")
        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id)
        if delete_result.endswith("success"):
            chunk_len = delete_chunk_len[0]
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": document_id,
                "chunk_len": -chunk_len,  # negative delta: the document shrinks
                "operate": "delete",
                "chunk_id": chunk_id
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error("Failed to delete from the vector store by chunk id")
            update_flag = False
            update_str = "delete by chunk id failed"

        if not update_flag:
            logger.error(update_str)
            delete_result = "delete_error"

        resp = generate_message(delete_result)
        return resp

    def _delete_by_doc_id(self, doc_id: str = None):
        logger.info(f"Deleting data for doc id: {doc_id}")
        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
        resp = generate_message(delete_result)
        return resp

    def _search_by_chunk_id(self, chunk_id):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
        else:
            query_result = []
        logger.info(f"Slice lookup result: {query_result}")
        resp = generate_response(query_result)
        return resp

    def _search_by_chunk_id_list(self, chunk_id_list):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list)
        else:
            query_result = []
        logger.info(f"Slice info for the recalled chunk list: {query_result}")
        return [chunk_dict.get("content") for chunk_dict in query_result]

    def _search_by_key_word(self, search_json):
        page_num = search_json.get("pageNum", 1)
        page_size = search_json.get("pageSize", 10)
        if self._has_collection():
            # Filter within the document identified by the incoming id.
            doc_id = search_json.get("document_id", None)
            text = search_json.get("text", None)
            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
        else:
            query_result = []
        resp = generate_response(query_result, page_num, page_size)
        return resp

    def _insert_data(self, docs):
        # Insert chunks one by one; stop at the first failure.
        insert_flag, insert_info = True, "success"
        for doc in docs:
            chunk = doc.get("content")
            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
            if not insert_flag:
                break
        return insert_flag, insert_info

    def _batch_insert_data(self, docs, text_lists):
        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
        return insert_flag, insert_info

    def _search(self, query, k, mode):
        search_result = self.hybrid_retriever.search(query, k, mode)
        return search_result

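
# MilvusOperate handles the vector side (Milvus via HybridRetriever);
# MysqlOperate below keeps the relational bookkeeping (slice_info rows,
# word counts in bm_document / bm_knowledge) consistent with it: every
# successful vector insert/update/delete is mirrored through
# update_total_doc_len.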

class MysqlOperate:
    def get_connection(self):
        """
        Fetch a connection from the pool.
        :return: (connection, info) — the connection (or None) and a status string
        """
        connection = None
        event = threading.Event()

        def target():
            nonlocal connection
            try:
                connection = POOL.get_connection()
            finally:
                event.set()

        # Fetch the connection in a daemon thread so a hung pool lookup
        # cannot block interpreter shutdown; wait at most 5 seconds.
        thread = threading.Thread(target=target, daemon=True)
        thread.start()
        event.wait(timeout=5)
        if thread.is_alive():
            logger.error("Timed out acquiring a connection from the pool")
            return None, "timed out acquiring connection"
        if connection:
            return connection, "success"
        logger.error("Failed to acquire a connection from the pool")
        return None, "failed to acquire connection"
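
    # Note: if the pool call completes after the 5-second window, the late
    # connection is never returned to the caller and never close()d, so that
    # pool slot is effectively leaked — worth tracking if timeouts are
    # frequent.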

    def insert_to_slice(self, docs, knowledge_id, doc_id):
        """
        Insert chunk rows into the slice_info table.
        """
        cursor = None
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = []
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info

        for chunk in docs:
            slice_id = chunk.get("chunk_id")
            slice_text = chunk.get("content")
            chunk_index = chunk.get("metadata").get("chunk_index")
            values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index))
        try:
            cursor = connection.cursor()
            insert_sql = """
                INSERT INTO slice_info (
                    slice_id,
                    knowledge_id,
                    document_id,
                    slice_text,
                    create_time,
                    slice_index
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """
            cursor.executemany(insert_sql, values)
            connection.commit()
            logger.info("Batch insert of slice rows succeeded.")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
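
    # Each element of `docs` is expected to look like (shape inferred from
    # the loop above; values are hypothetical):
    # {
    #     "chunk_id": "uuid-...",
    #     "content": "chunk text ...",
    #     "metadata": {"chunk_index": 0},
    # }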

    def delete_to_slice(self, doc_id):
        """
        Delete a document's rows from the slice_info table.
        """
        cursor = None
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        try:
            cursor = connection.cursor()
            delete_sql = "DELETE FROM slice_info WHERE document_id = %s"
            cursor.execute(delete_sql, (doc_id,))
            connection.commit()
            logger.info("Slice rows deleted successfully")
            return True, "success"
        except Error as e:
            logger.error(f"Failed to delete slice rows for {doc_id}: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()

    def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
        """
        Batch-insert image replacements into the bm_media_replacement table.
        """
        cursor = None
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info

        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = []
        for origin_text, media_url in image_dict.items():
            values.append((knowledge_id, doc_id, origin_text, "image", media_url, date_now))
        try:
            cursor = connection.cursor()
            insert_sql = """
                INSERT INTO bm_media_replacement (
                    knowledge_id,
                    document_id,
                    origin_text,
                    media_type,
                    media_url,
                    create_time
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """
            cursor.executemany(insert_sql, values)
            connection.commit()
            logger.info("Insert into bm_media_replacement succeeded")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
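
    # `image_dict` maps the placeholder text found in the document to the
    # uploaded media URL, e.g. (hypothetical values):
    # {"<image:fig-1>": "https://cdn.example.com/fig-1.png"}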

    def delete_image_url(self, doc_id):
        """
        Delete a document's rows from bm_media_replacement by doc id.
        """
        cursor = None
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info

        try:
            cursor = connection.cursor()
            delete_sql = "DELETE FROM bm_media_replacement WHERE document_id = %s"
            cursor.execute(delete_sql, (doc_id,))
            connection.commit()
            logger.info(f"Deleted bm_media_replacement rows for {doc_id}")
            return True, "success"
        except Error as e:
            logger.error(f"Failed to delete bm_media_replacement rows for {doc_id}: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()

    def update_total_doc_len(self, update_json):
        """
        Keep the relational bookkeeping in step with a vector-store change:
        update word counts in bm_document and bm_knowledge, and insert,
        update, or delete the matching slice_info row.

        Expected keys in update_json: knowledge_id, doc_id, chunk_len
        (a signed delta; negative for deletes), operate ("insert", "update",
        or "delete"), chunk_id, and chunk_text (unused for deletes).
        """
        knowledge_id = update_json.get("knowledge_id")
        doc_id = update_json.get("doc_id")
        chunk_len = update_json.get("chunk_len")
        operate = update_json.get("operate")
        chunk_id = update_json.get("chunk_id")
        chunk_text = update_json.get("chunk_text")
        cursor = None
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        try:
            cursor = connection.cursor()
            query_doc_word_num_sql = "SELECT word_num, slice_total FROM bm_document WHERE document_id = %s"
            query_knowledge_word_num_sql = "SELECT word_num FROM bm_knowledge WHERE knowledge_id = %s"
            cursor.execute(query_doc_word_num_sql, (doc_id,))
            doc_result = cursor.fetchone()
            logger.info(f"Document length info: {doc_result}")
            cursor.execute(query_knowledge_word_num_sql, (knowledge_id,))
            knowledge_result = cursor.fetchone()
            logger.info(f"Knowledge-base total length info: {knowledge_result}")
            if not doc_result:
                new_word_num = 0
                slice_total = 0
            else:
                old_word_num = doc_result[0]
                slice_total = doc_result[1]
                new_word_num = old_word_num + chunk_len
                # The decremented count is only persisted in the delete branch below.
                if slice_total:
                    slice_total -= 1
            if not knowledge_result:
                new_knowledge_word_num = 0
            else:
                old_knowledge_word_num = knowledge_result[0]
                new_knowledge_word_num = old_knowledge_word_num + chunk_len
            if operate == "update":
                update_sql = "UPDATE bm_document SET word_num = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, doc_id))
                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                update_slice_sql = "UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s"
                cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id))
            elif operate == "insert":
                # Append the new slice after the current highest index.
                query_slice_info_index_sql = "SELECT MAX(slice_index) FROM slice_info WHERE document_id = %s"
                cursor.execute(query_slice_info_index_sql, (doc_id,))
                chunk_index_result = cursor.fetchone()[0]
                if chunk_index_result is not None:
                    chunk_max_index = int(chunk_index_result)
                else:
                    chunk_max_index = 0
                update_sql = "UPDATE bm_document SET word_num = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, doc_id))
                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                insert_slice_sql = "INSERT INTO slice_info (slice_id, knowledge_id, document_id, slice_text, create_time, slice_index) VALUES (%s, %s, %s, %s, %s, %s)"
                cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index + 1))
            else:
                # "delete": shrink the word count, drop one from the slice
                # count, and remove the matching slice_info row.
                update_sql = "UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, slice_total, doc_id))
                delete_slice_sql = "DELETE FROM slice_info WHERE slice_id = %s"
                cursor.execute(delete_slice_sql, (chunk_id,))
            update_knowledge_sql = "UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s"
            cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id))
            connection.commit()
            logger.info("bm_document and bm_knowledge updated successfully")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
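
if __name__ == "__main__":
    # Minimal smoke-test sketch (hypothetical collection name and ids;
    # assumes a reachable Milvus at milvus_uri, MySQL per mysql_config, and
    # that HybridRetriever.search accepts mode="hybrid" — an assumption).
    op = MilvusOperate(collection_name="demo_kb", embedding_name="e5")
    if not op._has_collection():
        print(op._create_collection())
    print(op._insert_slice({
        "knowledge_id": "kb-001",      # hypothetical id
        "document_id": "doc-001",      # hypothetical id
        "doc_name": "demo.txt",
        "slice_text": "An example chunk of text.",
    }))
    print(op._search("example", k=3, mode="hybrid"))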