import re
import asyncio
import aiofiles
from typing import List, Tuple
from utils.get_logger import setup_logger

logger = setup_logger(__name__)

# Sentinel that temporarily stands in for each HTML table while the document
# is processed line by line. NUL delimiters make a collision with real text
# practically impossible; the number is an index into a per-call list.
_HTML_TABLE_PLACEHOLDER = "\x00HTML_TABLE_{}\x00"
_HTML_TABLE_PLACEHOLDER_RE = re.compile(r"\x00HTML_TABLE_(\d+)\x00")
# A whole HTML table. DOTALL lets it span several lines; the lazy ".*?" stops
# at the first closing tag (nested tables are not supported).
_HTML_TABLE_RE = re.compile(r"<table\b[^>]*>.*?</table>", re.DOTALL | re.IGNORECASE)
# Opening tag, body and (last) closing tag of one HTML table.
_HTML_TABLE_PARTS_RE = re.compile(r"(<table\b[^>]*>)(.*)(</table>)", re.DOTALL | re.IGNORECASE)
# One <tr>...</tr> row.
_HTML_ROW_RE = re.compile(r"<tr\b[^>]*>.*?</tr>", re.DOTALL | re.IGNORECASE)
# ATX heading: 1-6 '#' characters, whitespace, then the heading text.
_TITLE_RE = re.compile(r"^(#{1,6})\s+(.+)$")
# Sentence terminators (full- and half-width) and newlines. The capturing
# group makes re.split keep the delimiters so no text is lost.
_SENTENCE_SPLIT_RE = re.compile(r"([。！？!?\n])")


def _is_markdown_table_line(line: str) -> bool:
    """Return True if ``line`` looks like a pipe-delimited Markdown table row."""
    stripped = line.strip()
    return stripped.startswith('|') and stripped.endswith('|')


class MarkdownAtom:
    """An indivisible unit of Markdown content.

    Attributes are plain instance fields:
        content: raw text of the unit.
        atom_type: one of 'title', 'table', 'placeholder', 'text'.
        level: heading level (1-6) for 'title' atoms, 0 otherwise.
        title: stripped heading text for 'title' atoms, "" otherwise.
    """

    def __init__(self, content: str, atom_type: str, level: int = 0, title: str = ""):
        self.content = content
        self.atom_type = atom_type  # 'title', 'table', 'placeholder', 'text'
        self.level = level
        self.title = title

    def __len__(self):
        return len(self.content)

    def __repr__(self):
        if self.atom_type == "title":
            return f"<Title L{self.level}: {self.title!r} ({len(self)} chars)>"
        return f"<{self.atom_type.capitalize()}: {len(self)} chars>"


class MarkdownSplitter:
    """Split Markdown files into size-bounded chunks.

    Tables (Markdown and HTML) and 【...】 placeholders are kept intact whenever
    they fit within ``max_chunk_size``. Only oversized sections are split, and
    consecutive pieces of a split section share an overlap of roughly
    ``overlap_ratio`` of the previous piece's length.
    """

    def __init__(self, max_chunk_size: int = 5000, overlap_ratio: float = 0.1):
        """
        Args:
            max_chunk_size: Target maximum chunk length in characters. A single
                Markdown table line or HTML <tr> row longer than this is kept
                whole and can still exceed it.
            overlap_ratio: Fraction (0-1) of a split chunk's length that is
                repeated at the start of the following chunk.

        Raises:
            ValueError: If ``max_chunk_size`` is not positive or
                ``overlap_ratio`` is outside [0, 1].
        """
        if max_chunk_size <= 0:
            raise ValueError(f"max_chunk_size must be positive, got {max_chunk_size}")
        if not 0 <= overlap_ratio <= 1:
            raise ValueError(f"overlap_ratio must be within [0, 1], got {overlap_ratio}")
        self.max_chunk_size = max_chunk_size
        self.overlap_ratio = overlap_ratio

    async def read_markdown_file(self, file_path: str) -> str:
        """Read a UTF-8 Markdown file asynchronously and return its text.

        Raises:
            Exception: Any error from opening/reading the file is logged and
                re-raised unchanged.
        """
        try:
            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
                content = await f.read()
        except Exception as e:
            logger.error(f"读取文件失败 [{file_path}]: {e}")
            raise
        logger.info(f"读取文件: {file_path}, 长度: {len(content)}")
        return content

    def _extract_atoms_sync(self, content: str) -> List[MarkdownAtom]:
        """Parse ``content`` into a flat list of atoms (blocking).

        HTML tables are first swapped for sentinels so a multi-line table
        survives the line-by-line pass intact. Each resulting line is then
        classified, in priority order, as: a line holding HTML table
        sentinel(s), a heading, part of a Markdown table (consecutive pipe rows
        are grouped into one atom), a line with 【...】 placeholders, or plain
        text. Each line's newline is preserved in the emitted atoms.
        """
        atoms: List[MarkdownAtom] = []
        html_tables: List[str] = []

        def replace_html_table(match):
            # Keep the original table markup verbatim; the sentinel records its index.
            html_tables.append(match.group(0))
            return _HTML_TABLE_PLACEHOLDER.format(len(html_tables) - 1)

        lines = _HTML_TABLE_RE.sub(replace_html_table, content).split('\n')
        i = 0
        while i < len(lines):
            line = lines[i]

            # 1. Lines containing HTML table sentinels.
            if _HTML_TABLE_PLACEHOLDER_RE.search(line):
                self._append_html_table_line(line, html_tables, atoms)
                i += 1
                continue

            # 2. Headings.
            title_match = _TITLE_RE.match(line)
            if title_match:
                level = len(title_match.group(1))
                title_text = title_match.group(2).strip()
                atoms.append(MarkdownAtom(line + '\n', 'title', level, title_text))
                i += 1
                continue

            # 3. Markdown tables: group all consecutive pipe-delimited rows.
            if _is_markdown_table_line(line):
                start = i
                while i < len(lines) and _is_markdown_table_line(lines[i]):
                    i += 1
                atoms.append(MarkdownAtom('\n'.join(lines[start:i]) + '\n', 'table'))
                continue

            # 4. 【...】 placeholders.
            if '【' in line and '】' in line:
                self._append_bracket_placeholder_line(line, atoms)
                i += 1
                continue

            # 5. Plain text.
            atoms.append(MarkdownAtom(line + '\n', 'text'))
            i += 1

        logger.info(f"提取原子: {len(atoms)} 个 (含 {len(html_tables)} 个 HTML 表格)")
        return atoms

    @staticmethod
    def _append_html_table_line(line: str, html_tables: List[str],
                                atoms: List[MarkdownAtom]) -> None:
        """Emit atoms for a line containing one or more HTML table sentinels.

        A line consisting of exactly one sentinel yields a single table atom
        carrying the line's newline. Otherwise non-blank text around the
        sentinels becomes text atoms, every sentinel becomes a table atom, and
        a final text atom carries any trailing text plus the newline so the
        next line is not glued onto the table.
        """
        matches = list(_HTML_TABLE_PLACEHOLDER_RE.finditer(line))
        if len(matches) == 1 and line.strip() == matches[0].group(0):
            atoms.append(MarkdownAtom(html_tables[int(matches[0].group(1))] + '\n', 'table'))
            return

        pos = 0
        for match in matches:
            before = line[pos:match.start()]
            if before.strip():
                atoms.append(MarkdownAtom(before, 'text'))
            atoms.append(MarkdownAtom(html_tables[int(match.group(1))], 'table'))
            pos = match.end()
        after = line[pos:]
        atoms.append(MarkdownAtom((after if after.strip() else '') + '\n', 'text'))

    @staticmethod
    def _append_bracket_placeholder_line(line: str, atoms: List[MarkdownAtom]) -> None:
        """Emit atoms for a line containing 【...】 placeholders.

        Each complete 【...】 becomes a 'placeholder' atom and the text between
        them becomes 'text' atoms; an unmatched 【 stays in the text. The
        line's newline is always emitted as part of the final text atom.
        """
        text_start = 0
        while True:
            open_pos = line.find('【', text_start)
            if open_pos == -1:
                break
            close_pos = line.find('】', open_pos)
            if close_pos == -1:
                break
            if open_pos > text_start:
                atoms.append(MarkdownAtom(line[text_start:open_pos], 'text'))
            atoms.append(MarkdownAtom(line[open_pos:close_pos + 1], 'placeholder'))
            text_start = close_pos + 1
        atoms.append(MarkdownAtom(line[text_start:] + '\n', 'text'))

    async def extract_atoms(self, content: str) -> List[MarkdownAtom]:
        """Extract title/table/placeholder/text atoms in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._extract_atoms_sync, content)

    async def split_large_atom(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
        """Split an oversized atom in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._split_large_atom_sync, atom)

    def _split_large_atom_sync(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
        """Split an atom longer than ``max_chunk_size`` into smaller atoms (blocking).

        HTML tables are split between <tr> rows, Markdown tables between
        lines, and everything else at sentence boundaries, with a hard
        character split for any sentence still too long. Atoms that already
        fit are returned unchanged as a one-element list.
        """
        if len(atom) <= self.max_chunk_size:
            return [atom]

        logger.warning(f"原子超长: {atom}, 进行切分")
        if atom.atom_type == "table":
            if '<table' in atom.content.lower():
                result = self._split_html_table(atom.content)
            else:
                result = self._split_markdown_table(atom.content)
        else:
            result = self._split_text_atom(atom)
        logger.info(f"切分完成: {len(atom)} -> {len(result)} 块")
        return result

    def _hard_split(self, text: str) -> List[str]:
        """Cut ``text`` into consecutive slices of at most ``max_chunk_size`` chars."""
        size = self.max_chunk_size
        return [text[k:k + size] for k in range(0, len(text), size)]

    def _split_html_table(self, content: str) -> List[MarkdownAtom]:
        """Split an HTML table between rows, re-wrapping each piece in the
        original opening and closing <table> tags.

        Markup between the table tags that is not inside a <tr> (for example
        <thead>/<tbody> wrappers or a <caption>) is not carried over. A single
        row longer than the limit is kept whole so the HTML stays well-formed.
        If no rows can be found, the table is hard-split by characters so no
        content is lost. Text outside the table tags is kept on the first/last
        piece.
        """
        match = _HTML_TABLE_PARTS_RE.search(content)
        rows = _HTML_ROW_RE.findall(match.group(2)) if match else []
        if not rows:
            return [MarkdownAtom(piece, 'table') for piece in self._hard_split(content)]

        opener, closer = match.group(1), match.group(3)
        pieces: List[str] = []
        chunk = opener
        for row in rows:
            if chunk != opener and len(chunk) + len(row) + len(closer) > self.max_chunk_size:
                pieces.append(chunk + closer)
                chunk = opener
            chunk += row
        pieces.append(chunk + closer)

        pieces[0] = content[:match.start()] + pieces[0]
        pieces[-1] += content[match.end():]  # typically the line's trailing newline
        return [MarkdownAtom(piece, 'table') for piece in pieces]

    def _split_markdown_table(self, content: str) -> List[MarkdownAtom]:
        """Split a Markdown table between lines into pieces within the limit.

        Each piece ends with a newline; a single line longer than the limit is
        kept whole.
        """
        result: List[MarkdownAtom] = []
        chunk_lines: List[str] = []
        chunk_len = 0
        # rstrip so the atom's own trailing newline does not yield an empty row.
        for line in content.rstrip('\n').split('\n'):
            line_len = len(line) + 1
            if chunk_lines and chunk_len + line_len > self.max_chunk_size:
                result.append(MarkdownAtom('\n'.join(chunk_lines) + '\n', 'table'))
                chunk_lines, chunk_len = [], 0
            chunk_lines.append(line)
            chunk_len += line_len
        if chunk_lines:
            result.append(MarkdownAtom('\n'.join(chunk_lines) + '\n', 'table'))
        return result

    def _split_text_atom(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
        """Split a non-table atom at sentence boundaries.

        Sentences (text up to and including 。！？!? or a newline) are packed
        greedily up to the limit; a sentence longer than the limit is
        hard-split. Every piece keeps the atom's type, level and title.
        """
        parts = _SENTENCE_SPLIT_RE.split(atom.content)
        pieces: List[str] = []
        # re.split with a capturing group alternates text, delimiter, text, ...
        for k in range(0, len(parts), 2):
            sentence = ''.join(parts[k:k + 2])
            if len(sentence) > self.max_chunk_size:
                pieces.extend(self._hard_split(sentence))
            elif sentence:
                pieces.append(sentence)

        result: List[MarkdownAtom] = []
        buffer: List[str] = []
        buffer_len = 0
        for piece in pieces:
            if buffer and buffer_len + len(piece) > self.max_chunk_size:
                result.append(MarkdownAtom(''.join(buffer), atom.atom_type, atom.level, atom.title))
                buffer, buffer_len = [], 0
            buffer.append(piece)
            buffer_len += len(piece)
        if buffer:
            result.append(MarkdownAtom(''.join(buffer), atom.atom_type, atom.level, atom.title))
        return result

    def calc_overlap_atoms(self, atoms: List[MarkdownAtom]) -> List[MarkdownAtom]:
        """Return the shortest suffix of ``atoms`` whose total length reaches
        ``overlap_ratio`` of the whole list's length (empty if that target is 0).
        """
        if not atoms:
            return []
        target_len = int(sum(len(a) for a in atoms) * self.overlap_ratio)
        if target_len <= 0:
            return []

        overlap: List[MarkdownAtom] = []
        current_len = 0
        for atom in reversed(atoms):
            overlap.append(atom)
            current_len += len(atom)
            if current_len >= target_len:
                break
        overlap.reverse()
        return overlap

    def _trim_overlap(self, overlap: List[MarkdownAtom], incoming_len: int) -> List[MarkdownAtom]:
        """Drop atoms from the front of ``overlap`` until it plus an incoming
        atom of ``incoming_len`` characters fits within ``max_chunk_size``.
        """
        total = sum(len(a) for a in overlap)
        start = 0
        while start < len(overlap) and total + incoming_len > self.max_chunk_size:
            total -= len(overlap[start])
            start += 1
        return overlap[start:]

    async def smart_chunk_atoms(self, atoms: List[MarkdownAtom]) -> List[str]:
        """Pack atoms into chunks of at most ``max_chunk_size`` characters.

        Oversized atoms are split first; all other atoms (including tables and
        placeholders) are never cut. When a chunk is closed, the tail of it
        (see ``calc_overlap_atoms``) is repeated at the start of the next
        chunk, trimmed so the next chunk still fits. Overlap is only added at
        these split points.
        """
        processed: List[MarkdownAtom] = []
        for atom in atoms:
            if len(atom) > self.max_chunk_size:
                processed.extend(await self.split_large_atom(atom))
            else:
                processed.append(atom)

        chunks: List[str] = []
        current_atoms: List[MarkdownAtom] = []
        current_len = 0
        for atom in processed:
            atom_len = len(atom)
            if current_len > 0 and current_len + atom_len > self.max_chunk_size:
                chunks.append(''.join(a.content for a in current_atoms))
                current_atoms = self._trim_overlap(self.calc_overlap_atoms(current_atoms), atom_len)
                current_len = sum(len(a) for a in current_atoms)
            current_atoms.append(atom)
            current_len += atom_len

        if current_atoms:
            chunks.append(''.join(a.content for a in current_atoms))
        return chunks

    async def _append_block(self, block: List[MarkdownAtom], chunks: List[str], label: str) -> None:
        """Append ``block`` to ``chunks`` as one chunk, or smart-split it if it
        is longer than ``max_chunk_size``. Whitespace-only blocks are skipped.
        ``label`` is the section word used in the oversize log message.
        """
        text = ''.join(a.content for a in block)
        if not text.strip():
            return
        if len(text) <= self.max_chunk_size:
            chunks.append(text)
            return
        logger.info(f"{label}超长 ({len(text)} 字符),智能切分")
        chunks.extend(await self.smart_chunk_atoms(block))

    async def split_by_min_paragraph(self, atoms: List[MarkdownAtom]) -> List[str]:
        """Minimum-paragraph mode: one chunk per heading.

        Each chunk is the chain of ancestor headings + the heading itself + the
        content up to the next heading (content may be empty). Text before the
        first heading forms its own chunk. The top grouping level is H1 if any
        exist, otherwise H2; oversized chunks are smart-split.
        """
        level1_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 1)
        level2_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 2)

        if level1_count > 0:
            top_level = 1
            logger.info(f"检测到 {level1_count} 个一级标题,按一级标题分组")
        elif level2_count > 0:
            top_level = 2
            logger.info(f"无一级标题,检测到 {level2_count} 个二级标题,按二级标题分组")
        else:
            top_level = None
            logger.info("无一级或二级标题,按全文处理")

        chunks: List[str] = []
        title_chain: List[MarkdownAtom] = []  # headings enclosing the current position
        i = 0
        n = len(atoms)
        while i < n:
            atom = atoms[i]
            if not atom.content.strip():
                i += 1
                continue

            if atom.atom_type == 'title':
                # A top-level heading starts a fresh chain.
                if top_level and atom.level == top_level:
                    title_chain = []
                # Ancestors are the chained headings with a smaller level.
                current_chain = [t for t in title_chain if t.level < atom.level] + [atom]
                i += 1
                start = i
                while i < n and atoms[i].atom_type != 'title':
                    i += 1
                await self._append_block(current_chain + atoms[start:i], chunks, "段落")
                title_chain = current_chain
            else:
                # Content not preceded by any heading (only at document start).
                start = i
                while i < n and atoms[i].atom_type != 'title':
                    i += 1
                await self._append_block(title_chain + atoms[start:i], chunks, "段落")

        logger.info(f"最小段落切分: {len(chunks)} 个切片")
        return chunks

    async def split_by_max_paragraph(self, atoms: List[MarkdownAtom]) -> List[str]:
        """Maximum-paragraph mode: one chunk per top-level section.

        Sections are delimited by H1 headings if any exist, otherwise by H2
        headings, otherwise the whole document is one section. Sections that
        fit are kept whole (no overlap); oversized ones are smart-split with
        overlap. Whitespace-only sections are skipped.
        """
        level1_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 1)
        level2_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 2)

        if level1_count > 0:
            split_level = 1
            logger.info(f"检测到 {level1_count} 个一级标题,按一级标题切分")
        elif level2_count > 0:
            split_level = 2
            logger.info(f"无一级标题,检测到 {level2_count} 个二级标题,按二级标题切分")
        else:
            split_level = None
            logger.info("无一级或二级标题,按全文处理")

        def starts_section(atom: MarkdownAtom) -> bool:
            return bool(split_level) and atom.atom_type == 'title' and atom.level == split_level

        chunks: List[str] = []
        i = 0
        n = len(atoms)
        while i < n:
            start = i
            i += 1
            while i < n and not starts_section(atoms[i]):
                i += 1
            await self._append_block(atoms[start:i], chunks, "章节")

        logger.info(f"最大段落切分: {len(chunks)} 个切片")
        return chunks

    async def split_markdown(self, file_path: str, split_mode: str = "min") -> List[str]:
        """Split a Markdown file into chunks.

        Args:
            file_path: Path to the Markdown file (read as UTF-8).
            split_mode: 'min' (heading to heading) or 'max' (top-level
                heading to top-level heading).

        Returns:
            The list of chunk strings.

        Raises:
            ValueError: If ``split_mode`` is not 'min' or 'max' (checked
                before the file is read).
        """
        splitters = {
            "min": self.split_by_min_paragraph,
            "max": self.split_by_max_paragraph,
        }
        if split_mode not in splitters:
            raise ValueError(f"不支持的模式: {split_mode},请使用 'min' 或 'max'")

        logger.info(f"开始切分: {file_path}, 模式: {split_mode}")
        content = await self.read_markdown_file(file_path)
        atoms = await self.extract_atoms(content)
        if not atoms:
            logger.warning("未提取到原子")
            return []
        return await splitters[split_mode](atoms)


# Usage example
if __name__ == "__main__":
    async def main():
        splitter = MarkdownSplitter(max_chunk_size=5000, overlap_ratio=0.1)
        test_file = "test.md"

        # Minimum-paragraph mode
        chunks_min = await splitter.split_markdown(test_file, "min")
        print(f"最小段落: {len(chunks_min)} 个切片")
        for i, chunk in enumerate(chunks_min[:3], 1):
            print(f"\n--- 切片 {i} ({len(chunk)} 字符) ---")
            print(chunk[:200] + "..." if len(chunk) > 200 else chunk)

        # Maximum-paragraph mode
        chunks_max = await splitter.split_markdown(test_file, "max")
        print(f"\n最大段落: {len(chunks_max)} 个切片")
        for i, chunk in enumerate(chunks_max[:3], 1):
            print(f"\n--- 切片 {i} ({len(chunk)} 字符) ---")
            print(chunk[:200] + "..." if len(chunk) > 200 else chunk)

    asyncio.run(main())