# md_splitter.py - parses a Markdown document into atoms (titles, tables,
# placeholders, text) and groups them into size-bounded chunks.
import asyncio
import re
from typing import List, Tuple

import aiofiles

from utils.get_logger import setup_logger

logger = setup_logger(__name__)
  7. class MarkdownAtom:
  8. """Markdown 原子类,表示一个不可分割的内容单元"""
  9. def __init__(self, content: str, atom_type: str, level: int = 0, title: str = ""):
  10. self.content = content
  11. self.atom_type = atom_type # 'title', 'table', 'placeholder', 'text'
  12. self.level = level
  13. self.title = title
  14. def __len__(self):
  15. return len(self.content)
  16. def __repr__(self):
  17. if self.atom_type == "title":
  18. return f"<Atom {self.atom_type} L{self.level} len={len(self)} '{self.title[:20]}'>"
  19. return f"<Atom {self.atom_type} len={len(self)}>"
  20. class MarkdownSplitter:
  21. """Markdown 文件智能切分器"""
  22. def __init__(self, max_chunk_size: int = 5000, overlap_ratio: float = 0.1):
  23. self.max_chunk_size = max_chunk_size
  24. self.overlap_ratio = overlap_ratio
  25. async def read_markdown_file(self, file_path: str) -> str:
  26. """读取 Markdown 文件内容"""
  27. try:
  28. async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
  29. content = await f.read()
  30. logger.info(f"读取文件: {file_path}, 长度: {len(content)}")
  31. return content
  32. except Exception as e:
  33. logger.error(f"读取文件失败 [{file_path}]: {e}")
  34. raise
  35. def _extract_atoms_sync(self, content: str) -> List[MarkdownAtom]:
  36. """同步提取原子"""
  37. atoms = []
  38. # 先提取 HTML 表格,保留原始格式
  39. html_tables = []
  40. html_placeholder = "<<<HTML_TABLE_{}>>>"
  41. def replace_html_table(match):
  42. idx = len(html_tables)
  43. # 保留原始 HTML 表格格式
  44. html_tables.append(match.group(0))
  45. return html_placeholder.format(idx)
  46. processed = re.sub(r'<table[\s\S]*?</table>', replace_html_table, content, flags=re.IGNORECASE)
  47. # 按行处理
  48. lines = processed.split('\n')
  49. i = 0
  50. while i < len(lines):
  51. line = lines[i]
  52. # 1. 处理 HTML 表格占位符
  53. html_match = re.search(r'<<<HTML_TABLE_(\d+)>>>', line)
  54. if html_match:
  55. table_idx = int(html_match.group(1))
  56. if line.strip() == html_match.group(0):
  57. # 单独一行
  58. atoms.append(MarkdownAtom(html_tables[table_idx] + '\n', 'table'))
  59. else:
  60. # 混合内容,分段处理
  61. before = line[:html_match.start()]
  62. after = line[html_match.end():]
  63. if before.strip():
  64. atoms.append(MarkdownAtom(before, 'text'))
  65. atoms.append(MarkdownAtom(html_tables[table_idx], 'table'))
  66. if after.strip():
  67. atoms.append(MarkdownAtom(after + '\n', 'text'))
  68. i += 1
  69. continue
  70. # 2. 处理标题
  71. title_match = re.match(r'^(#{1,6})\s+(.+)$', line)
  72. if title_match:
  73. level = len(title_match.group(1))
  74. title_text = title_match.group(2).strip()
  75. atoms.append(MarkdownAtom(line + '\n', 'title', level, title_text))
  76. i += 1
  77. continue
  78. # 3. 处理 Markdown 表格
  79. if line.strip().startswith('|') and line.strip().endswith('|'):
  80. table_lines = []
  81. while i < len(lines) and lines[i].strip().startswith('|') and lines[i].strip().endswith('|'):
  82. table_lines.append(lines[i])
  83. i += 1
  84. atoms.append(MarkdownAtom('\n'.join(table_lines) + '\n', 'table'))
  85. continue
  86. # 4. 处理占位符 【...】
  87. if '【' in line and '】' in line:
  88. j = 0
  89. temp_text = []
  90. while j < len(line):
  91. if line[j] == '【':
  92. if temp_text:
  93. atoms.append(MarkdownAtom(''.join(temp_text), 'text'))
  94. temp_text = []
  95. end = line.find('】', j)
  96. if end != -1:
  97. atoms.append(MarkdownAtom(line[j:end+1], 'placeholder'))
  98. j = end + 1
  99. else:
  100. temp_text.append(line[j])
  101. j += 1
  102. else:
  103. temp_text.append(line[j])
  104. j += 1
  105. if temp_text:
  106. atoms.append(MarkdownAtom(''.join(temp_text) + '\n', 'text'))
  107. i += 1
  108. continue
  109. # 5. 普通文本
  110. atoms.append(MarkdownAtom(line + '\n', 'text'))
  111. i += 1
  112. logger.info(f"提取原子: {len(atoms)} 个 (含 {len(html_tables)} 个 HTML 表格)")
  113. return atoms
  114. async def extract_atoms(self, content: str) -> List[MarkdownAtom]:
  115. """提取原子:标题、表格、占位符、文本"""
  116. loop = asyncio.get_event_loop()
  117. return await loop.run_in_executor(None, self._extract_atoms_sync, content)
  118. async def split_large_atom(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
  119. """切分超长原子"""
  120. loop = asyncio.get_event_loop()
  121. return await loop.run_in_executor(None, self._split_large_atom_sync, atom)
  122. def _split_large_atom_sync(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
  123. """切分超长原子"""
  124. if len(atom) <= self.max_chunk_size:
  125. return [atom]
  126. logger.warning(f"原子超长: {atom}, 进行切分")
  127. result = []
  128. content = atom.content
  129. if atom.atom_type == "table":
  130. # 表格按行切分
  131. if '<table' in content.lower():
  132. # HTML 表格
  133. match = re.match(r'(<table[^>]*>)(.*)(</table>)', content, re.DOTALL | re.IGNORECASE)
  134. if match:
  135. opener, body, closer = match.groups()
  136. rows = re.findall(r'<tr[^>]*>.*?</tr>', body, re.DOTALL | re.IGNORECASE)
  137. chunk = opener
  138. for row in rows:
  139. if len(chunk) + len(row) + len(closer) > self.max_chunk_size:
  140. if chunk != opener:
  141. result.append(MarkdownAtom(chunk + closer, 'table'))
  142. chunk = opener
  143. chunk += row
  144. if chunk != opener:
  145. result.append(MarkdownAtom(chunk + closer, 'table'))
  146. else:
  147. # 按字符强制切
  148. for i in range(0, len(content), self.max_chunk_size):
  149. result.append(MarkdownAtom(content[i:i+self.max_chunk_size], 'table'))
  150. else:
  151. # Markdown 表格按行切
  152. lines = content.split('\n')
  153. chunk_lines = []
  154. chunk_len = 0
  155. for line in lines:
  156. line_len = len(line) + 1
  157. if chunk_len + line_len > self.max_chunk_size and chunk_lines:
  158. result.append(MarkdownAtom('\n'.join(chunk_lines) + '\n', 'table'))
  159. chunk_lines = []
  160. chunk_len = 0
  161. chunk_lines.append(line)
  162. chunk_len += line_len
  163. if chunk_lines:
  164. result.append(MarkdownAtom('\n'.join(chunk_lines) + '\n', 'table'))
  165. else:
  166. # 文本/标题/占位符按句号切分
  167. sentences = re.split(r'([。!?\n])', content)
  168. combined = []
  169. for i in range(0, len(sentences), 2):
  170. if i + 1 < len(sentences):
  171. combined.append(sentences[i] + sentences[i+1])
  172. else:
  173. combined.append(sentences[i])
  174. chunk = []
  175. chunk_len = 0
  176. for sent in combined:
  177. if not sent:
  178. continue
  179. sent_len = len(sent)
  180. if chunk_len + sent_len > self.max_chunk_size and chunk:
  181. result.append(MarkdownAtom(''.join(chunk), atom.atom_type, atom.level, atom.title))
  182. chunk = []
  183. chunk_len = 0
  184. chunk.append(sent)
  185. chunk_len += sent_len
  186. if chunk:
  187. result.append(MarkdownAtom(''.join(chunk), atom.atom_type, atom.level, atom.title))
  188. logger.info(f"切分完成: {len(atom)} -> {len(result)} 块")
  189. return result
  190. def calc_overlap_atoms(self, atoms: List[MarkdownAtom]) -> List[MarkdownAtom]:
  191. """计算重叠原子(从后往前按比例收集)"""
  192. if not atoms:
  193. return []
  194. total_len = sum(len(a) for a in atoms)
  195. target_len = int(total_len * self.overlap_ratio)
  196. if target_len == 0:
  197. return []
  198. overlap = []
  199. current_len = 0
  200. for atom in reversed(atoms):
  201. overlap.insert(0, atom)
  202. current_len += len(atom)
  203. if current_len >= target_len:
  204. break
  205. return overlap
  206. async def smart_chunk_atoms(self, atoms: List[MarkdownAtom]) -> List[str]:
  207. """
  208. 智能组装原子为切片,保护表格/占位符完整性
  209. 仅在超长需要切分时添加重叠
  210. """
  211. # 先处理超长原子
  212. processed = []
  213. for atom in atoms:
  214. if len(atom) > self.max_chunk_size:
  215. processed.extend(await self.split_large_atom(atom))
  216. else:
  217. processed.append(atom)
  218. chunks = []
  219. current_atoms = []
  220. current_len = 0
  221. i = 0
  222. while i < len(processed):
  223. atom = processed[i]
  224. atom_len = len(atom)
  225. # 尝试保护表格/占位符完整性
  226. if atom.atom_type in ('table', 'placeholder'):
  227. # 如果当前块 + 该原子会超长
  228. if current_len > 0 and current_len + atom_len > self.max_chunk_size:
  229. # 检查该原子是否能单独成块
  230. if atom_len <= self.max_chunk_size:
  231. # 保存当前块
  232. chunks.append(''.join(a.content for a in current_atoms))
  233. # 计算重叠(仅在切分时)
  234. overlap = self.calc_overlap_atoms(current_atoms)
  235. # 新块包含重叠 + 该原子
  236. current_atoms = overlap + [atom]
  237. current_len = sum(len(a) for a in current_atoms)
  238. i += 1
  239. continue
  240. # 否则该原子太大,正常处理(会被切分)
  241. # 正常流程:检查是否超长
  242. if current_len > 0 and current_len + atom_len > self.max_chunk_size:
  243. # 保存当前块
  244. chunks.append(''.join(a.content for a in current_atoms))
  245. # 计算重叠
  246. overlap = self.calc_overlap_atoms(current_atoms)
  247. # 开始新块
  248. current_atoms = overlap
  249. current_len = sum(len(a) for a in current_atoms)
  250. # 添加原子
  251. current_atoms.append(atom)
  252. current_len += atom_len
  253. i += 1
  254. # 保存最后一块
  255. if current_atoms:
  256. chunks.append(''.join(a.content for a in current_atoms))
  257. return chunks
  258. async def split_by_min_paragraph(self, atoms: List[MarkdownAtom]) -> List[str]:
  259. """
  260. 最小段落模式:每个标题单独成块,包含父级标题链 + 当前标题 + 内容(内容可无)
  261. 优先按一级标题分组,若无一级标题则按二级标题分组
  262. """
  263. # 统计标题级别
  264. level1_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 1)
  265. level2_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 2)
  266. # 确定顶级标题级别
  267. if level1_count > 0:
  268. top_level = 1
  269. logger.info(f"检测到 {level1_count} 个一级标题,按一级标题分组")
  270. elif level2_count > 0:
  271. top_level = 2
  272. logger.info(f"无一级标题,检测到 {level2_count} 个二级标题,按二级标题分组")
  273. else:
  274. top_level = None
  275. logger.info("无一级或二级标题,按全文处理")
  276. chunks = []
  277. i = 0
  278. title_chain = [] # 父级标题链
  279. while i < len(atoms):
  280. atom = atoms[i]
  281. if not atom.content.strip():
  282. i += 1
  283. continue
  284. # 处理标题
  285. if atom.atom_type == 'title':
  286. # 顶级标题:重置标题链
  287. if top_level and atom.level == top_level:
  288. title_chain = []
  289. # 更新标题链:保留比当前标题级别低的父级标题
  290. parent_chain = [t for t in title_chain if t.level < atom.level]
  291. current_title_chain = parent_chain + [atom]
  292. # 收集该标题后的内容(直到下一个标题)
  293. i += 1
  294. content_atoms = []
  295. while i < len(atoms) and atoms[i].atom_type != 'title':
  296. content_atoms.append(atoms[i])
  297. i += 1
  298. # 组装切片:父级标题链 + 当前标题 + 内容
  299. full_block = current_title_chain + content_atoms
  300. block_len = sum(len(a) for a in full_block)
  301. if block_len <= self.max_chunk_size and full_block:
  302. chunks.append(''.join(a.content for a in full_block))
  303. elif full_block:
  304. logger.info(f"段落超长 ({block_len} 字符),智能切分")
  305. sub_chunks = await self.smart_chunk_atoms(full_block)
  306. chunks.extend(sub_chunks)
  307. # 更新标题链为当前完整链(用于下一个标题的父级链)
  308. title_chain = current_title_chain
  309. else:
  310. # 非标题内容(没有标题的情况)
  311. content_atoms = []
  312. while i < len(atoms) and atoms[i].atom_type != 'title':
  313. content_atoms.append(atoms[i])
  314. i += 1
  315. if content_atoms:
  316. full_block = title_chain + content_atoms
  317. block_len = sum(len(a) for a in full_block)
  318. if block_len <= self.max_chunk_size and full_block:
  319. chunks.append(''.join(a.content for a in full_block))
  320. elif full_block:
  321. logger.info(f"段落超长 ({block_len} 字符),智能切分")
  322. sub_chunks = await self.smart_chunk_atoms(full_block)
  323. chunks.extend(sub_chunks)
  324. logger.info(f"最小段落切分: {len(chunks)} 个切片")
  325. return chunks
  326. async def split_by_max_paragraph(self, atoms: List[MarkdownAtom]) -> List[str]:
  327. """
  328. 最大段落模式:自适应选择切分级别
  329. 优先按一级标题切分,若无一级标题则按二级标题切分
  330. 不超长直接保存,超长则智能切分(带重叠)
  331. """
  332. # 统计标题级别
  333. level1_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 1)
  334. level2_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 2)
  335. # 动态选择切分级别
  336. if level1_count > 0:
  337. split_level = 1
  338. logger.info(f"检测到 {level1_count} 个一级标题,按一级标题切分")
  339. elif level2_count > 0:
  340. split_level = 2
  341. logger.info(f"无一级标题,检测到 {level2_count} 个二级标题,按二级标题切分")
  342. else:
  343. split_level = None
  344. logger.info("无一级或二级标题,按全文处理")
  345. chunks = []
  346. i = 0
  347. while i < len(atoms):
  348. block_atoms = []
  349. # 收集到下一个目标级别标题之前的所有原子
  350. while i < len(atoms):
  351. block_atoms.append(atoms[i])
  352. i += 1
  353. # 如果下一个是目标级别标题,停止
  354. if split_level and i < len(atoms) and atoms[i].atom_type == 'title' and atoms[i].level == split_level:
  355. break
  356. block_len = sum(len(a) for a in block_atoms)
  357. # 不超长,直接保存(不重叠)
  358. if block_len <= self.max_chunk_size:
  359. chunks.append(''.join(a.content for a in block_atoms))
  360. else:
  361. # 超长,智能切分(带重叠)
  362. logger.info(f"章节超长 ({block_len} 字符),智能切分")
  363. sub_chunks = await self.smart_chunk_atoms(block_atoms)
  364. chunks.extend(sub_chunks)
  365. logger.info(f"最大段落切分: {len(chunks)} 个切片")
  366. return chunks
  367. async def split_markdown(self, file_path: str, split_mode: str = "min") -> List[str]:
  368. """
  369. 切分 Markdown 文件
  370. 参数:
  371. file_path: MD 文件路径
  372. split_mode: 'min' 最小段落(标题到标题), 'max' 最大段落(一级标题到一级标题)
  373. 返回:
  374. List[str]: 切片列表
  375. """
  376. logger.info(f"开始切分: {file_path}, 模式: {split_mode}")
  377. content = await self.read_markdown_file(file_path)
  378. atoms = await self.extract_atoms(content)
  379. if not atoms:
  380. logger.warning("未提取到原子")
  381. return []
  382. if split_mode == "min":
  383. return await self.split_by_min_paragraph(atoms)
  384. elif split_mode == "max":
  385. return await self.split_by_max_paragraph(atoms)
  386. else:
  387. raise ValueError(f"不支持的模式: {split_mode},请使用 'min' 或 'max'")
  388. # 使用示例
  389. if __name__ == "__main__":
  390. async def main():
  391. splitter = MarkdownSplitter(max_chunk_size=5000, overlap_ratio=0.1)
  392. test_file = "test.md"
  393. # 最小段落模式
  394. chunks_min = await splitter.split_markdown(test_file, "min")
  395. print(f"最小段落: {len(chunks_min)} 个切片")
  396. for i, chunk in enumerate(chunks_min[:3], 1):
  397. print(f"\n--- 切片 {i} ({len(chunk)} 字符) ---")
  398. print(chunk[:200] + "..." if len(chunk) > 200 else chunk)
  399. # 最大段落模式
  400. chunks_max = await splitter.split_markdown(test_file, "max")
  401. print(f"\n最大段落: {len(chunks_max)} 个切片")
  402. for i, chunk in enumerate(chunks_max[:3], 1):
  403. print(f"\n--- 切片 {i} ({len(chunk)} 字符) ---")
  404. print(chunk[:200] + "..." if len(chunk) > 200 else chunk)
  405. asyncio.run(main())