| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494 |
- import re
- import asyncio
- import aiofiles
- from typing import List, Tuple
- from utils.get_logger import setup_logger
- logger = setup_logger(__name__)
-
class MarkdownAtom:
    """An indivisible unit of Markdown content.

    Attributes:
        content: Raw text of the atom.
        atom_type: One of 'title', 'table', 'placeholder', 'text'.
        level: Heading depth; only meaningful for 'title' atoms (0 otherwise).
        title: Heading text; only meaningful for 'title' atoms.
    """

    def __init__(self, content: str, atom_type: str, level: int = 0, title: str = ""):
        self.atom_type = atom_type
        self.content = content
        self.title = title
        self.level = level

    def __len__(self):
        """Return the number of characters in ``content``."""
        return len(self.content)

    def __repr__(self):
        """Return a short debug form; titles also show level and a 20-char title prefix."""
        size = len(self)
        if self.atom_type != "title":
            return f"<Atom {self.atom_type} len={size}>"
        short_title = self.title[:20]
        return f"<Atom {self.atom_type} L{self.level} len={size} '{short_title}'>"
class MarkdownSplitter:
    """Structure-aware splitter that cuts a Markdown document into chunks."""

    def __init__(self, max_chunk_size: int = 5000, overlap_ratio: float = 0.1):
        """Configure the splitter.

        Args:
            max_chunk_size: Target upper bound, in characters, for one chunk.
            overlap_ratio: Fraction of the previous chunk to repeat at the
                start of the next one when an oversized section is cut.

        Raises:
            ValueError: If ``max_chunk_size`` is not positive, or
                ``overlap_ratio`` is outside ``[0, 1)``.
        """
        # A non-positive size makes the forced character split call
        # range() with a zero/negative step.
        if max_chunk_size <= 0:
            raise ValueError(f"max_chunk_size 必须为正数: {max_chunk_size}")
        # A ratio of 1 or more would repeat the entire previous chunk.
        if not 0 <= overlap_ratio < 1:
            raise ValueError(f"overlap_ratio 必须在 [0, 1) 范围内: {overlap_ratio}")
        self.max_chunk_size = max_chunk_size
        self.overlap_ratio = overlap_ratio
-
- async def read_markdown_file(self, file_path: str) -> str:
- """读取 Markdown 文件内容"""
- try:
- async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
- content = await f.read()
- logger.info(f"读取文件: {file_path}, 长度: {len(content)}")
- return content
- except Exception as e:
- logger.error(f"读取文件失败 [{file_path}]: {e}")
- raise
-
-
-
- def _extract_atoms_sync(self, content: str) -> List[MarkdownAtom]:
- """同步提取原子"""
- atoms = []
-
- # 先提取 HTML 表格,保留原始格式
- html_tables = []
- html_placeholder = "<<<HTML_TABLE_{}>>>"
-
- def replace_html_table(match):
- idx = len(html_tables)
- # 保留原始 HTML 表格格式
- html_tables.append(match.group(0))
- return html_placeholder.format(idx)
-
- processed = re.sub(r'<table[\s\S]*?</table>', replace_html_table, content, flags=re.IGNORECASE)
-
- # 按行处理
- lines = processed.split('\n')
- i = 0
-
- while i < len(lines):
- line = lines[i]
- # 1. 处理 HTML 表格占位符
- html_match = re.search(r'<<<HTML_TABLE_(\d+)>>>', line)
- if html_match:
- table_idx = int(html_match.group(1))
- if line.strip() == html_match.group(0):
- # 单独一行
- atoms.append(MarkdownAtom(html_tables[table_idx] + '\n', 'table'))
- else:
- # 混合内容,分段处理
- before = line[:html_match.start()]
- after = line[html_match.end():]
- if before.strip():
- atoms.append(MarkdownAtom(before, 'text'))
- atoms.append(MarkdownAtom(html_tables[table_idx], 'table'))
- if after.strip():
- atoms.append(MarkdownAtom(after + '\n', 'text'))
- i += 1
- continue
-
- # 2. 处理标题
- title_match = re.match(r'^(#{1,6})\s+(.+)$', line)
- if title_match:
- level = len(title_match.group(1))
- title_text = title_match.group(2).strip()
- atoms.append(MarkdownAtom(line + '\n', 'title', level, title_text))
- i += 1
- continue
-
- # 3. 处理 Markdown 表格
- if line.strip().startswith('|') and line.strip().endswith('|'):
- table_lines = []
- while i < len(lines) and lines[i].strip().startswith('|') and lines[i].strip().endswith('|'):
- table_lines.append(lines[i])
- i += 1
- atoms.append(MarkdownAtom('\n'.join(table_lines) + '\n', 'table'))
- continue
-
- # 4. 处理占位符 【...】
- if '【' in line and '】' in line:
- j = 0
- temp_text = []
- while j < len(line):
- if line[j] == '【':
- if temp_text:
- atoms.append(MarkdownAtom(''.join(temp_text), 'text'))
- temp_text = []
- end = line.find('】', j)
- if end != -1:
- atoms.append(MarkdownAtom(line[j:end+1], 'placeholder'))
- j = end + 1
- else:
- temp_text.append(line[j])
- j += 1
- else:
- temp_text.append(line[j])
- j += 1
- if temp_text:
- atoms.append(MarkdownAtom(''.join(temp_text) + '\n', 'text'))
- i += 1
- continue
-
- # 5. 普通文本
- atoms.append(MarkdownAtom(line + '\n', 'text'))
- i += 1
-
- logger.info(f"提取原子: {len(atoms)} 个 (含 {len(html_tables)} 个 HTML 表格)")
- return atoms
-
- async def extract_atoms(self, content: str) -> List[MarkdownAtom]:
- """提取原子:标题、表格、占位符、文本"""
- loop = asyncio.get_event_loop()
- return await loop.run_in_executor(None, self._extract_atoms_sync, content)
-
- async def split_large_atom(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
- """切分超长原子"""
- loop = asyncio.get_event_loop()
- return await loop.run_in_executor(None, self._split_large_atom_sync, atom)
-
- def _split_large_atom_sync(self, atom: MarkdownAtom) -> List[MarkdownAtom]:
- """切分超长原子"""
- if len(atom) <= self.max_chunk_size:
- return [atom]
-
- logger.warning(f"原子超长: {atom}, 进行切分")
- result = []
- content = atom.content
-
- if atom.atom_type == "table":
- # 表格按行切分
- if '<table' in content.lower():
- # HTML 表格
- match = re.match(r'(<table[^>]*>)(.*)(</table>)', content, re.DOTALL | re.IGNORECASE)
- if match:
- opener, body, closer = match.groups()
- rows = re.findall(r'<tr[^>]*>.*?</tr>', body, re.DOTALL | re.IGNORECASE)
-
- chunk = opener
- for row in rows:
- if len(chunk) + len(row) + len(closer) > self.max_chunk_size:
- if chunk != opener:
- result.append(MarkdownAtom(chunk + closer, 'table'))
- chunk = opener
- chunk += row
-
- if chunk != opener:
- result.append(MarkdownAtom(chunk + closer, 'table'))
- else:
- # 按字符强制切
- for i in range(0, len(content), self.max_chunk_size):
- result.append(MarkdownAtom(content[i:i+self.max_chunk_size], 'table'))
- else:
- # Markdown 表格按行切
- lines = content.split('\n')
- chunk_lines = []
- chunk_len = 0
-
- for line in lines:
- line_len = len(line) + 1
- if chunk_len + line_len > self.max_chunk_size and chunk_lines:
- result.append(MarkdownAtom('\n'.join(chunk_lines) + '\n', 'table'))
- chunk_lines = []
- chunk_len = 0
- chunk_lines.append(line)
- chunk_len += line_len
-
- if chunk_lines:
- result.append(MarkdownAtom('\n'.join(chunk_lines) + '\n', 'table'))
-
- else:
- # 文本/标题/占位符按句号切分
- sentences = re.split(r'([。!?\n])', content)
- combined = []
- for i in range(0, len(sentences), 2):
- if i + 1 < len(sentences):
- combined.append(sentences[i] + sentences[i+1])
- else:
- combined.append(sentences[i])
-
- chunk = []
- chunk_len = 0
- for sent in combined:
- if not sent:
- continue
- sent_len = len(sent)
-
- if chunk_len + sent_len > self.max_chunk_size and chunk:
- result.append(MarkdownAtom(''.join(chunk), atom.atom_type, atom.level, atom.title))
- chunk = []
- chunk_len = 0
-
- chunk.append(sent)
- chunk_len += sent_len
-
- if chunk:
- result.append(MarkdownAtom(''.join(chunk), atom.atom_type, atom.level, atom.title))
-
- logger.info(f"切分完成: {len(atom)} -> {len(result)} 块")
- return result
-
- def calc_overlap_atoms(self, atoms: List[MarkdownAtom]) -> List[MarkdownAtom]:
- """计算重叠原子(从后往前按比例收集)"""
- if not atoms:
- return []
-
- total_len = sum(len(a) for a in atoms)
- target_len = int(total_len * self.overlap_ratio)
-
- if target_len == 0:
- return []
-
- overlap = []
- current_len = 0
-
- for atom in reversed(atoms):
- overlap.insert(0, atom)
- current_len += len(atom)
- if current_len >= target_len:
- break
-
- return overlap
-
- async def smart_chunk_atoms(self, atoms: List[MarkdownAtom]) -> List[str]:
- """
- 智能组装原子为切片,保护表格/占位符完整性
- 仅在超长需要切分时添加重叠
- """
- # 先处理超长原子
- processed = []
- for atom in atoms:
- if len(atom) > self.max_chunk_size:
- processed.extend(await self.split_large_atom(atom))
- else:
- processed.append(atom)
-
- chunks = []
- current_atoms = []
- current_len = 0
-
- i = 0
- while i < len(processed):
- atom = processed[i]
- atom_len = len(atom)
-
- # 尝试保护表格/占位符完整性
- if atom.atom_type in ('table', 'placeholder'):
- # 如果当前块 + 该原子会超长
- if current_len > 0 and current_len + atom_len > self.max_chunk_size:
- # 检查该原子是否能单独成块
- if atom_len <= self.max_chunk_size:
- # 保存当前块
- chunks.append(''.join(a.content for a in current_atoms))
-
- # 计算重叠(仅在切分时)
- overlap = self.calc_overlap_atoms(current_atoms)
-
- # 新块包含重叠 + 该原子
- current_atoms = overlap + [atom]
- current_len = sum(len(a) for a in current_atoms)
- i += 1
- continue
- # 否则该原子太大,正常处理(会被切分)
-
- # 正常流程:检查是否超长
- if current_len > 0 and current_len + atom_len > self.max_chunk_size:
- # 保存当前块
- chunks.append(''.join(a.content for a in current_atoms))
-
- # 计算重叠
- overlap = self.calc_overlap_atoms(current_atoms)
-
- # 开始新块
- current_atoms = overlap
- current_len = sum(len(a) for a in current_atoms)
-
- # 添加原子
- current_atoms.append(atom)
- current_len += atom_len
- i += 1
-
- # 保存最后一块
- if current_atoms:
- chunks.append(''.join(a.content for a in current_atoms))
-
- return chunks
-
- async def split_by_min_paragraph(self, atoms: List[MarkdownAtom]) -> List[str]:
- """
- 最小段落模式:每个标题单独成块,包含父级标题链 + 当前标题 + 内容(内容可无)
- 优先按一级标题分组,若无一级标题则按二级标题分组
- """
- # 统计标题级别
- level1_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 1)
- level2_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 2)
-
- # 确定顶级标题级别
- if level1_count > 0:
- top_level = 1
- logger.info(f"检测到 {level1_count} 个一级标题,按一级标题分组")
- elif level2_count > 0:
- top_level = 2
- logger.info(f"无一级标题,检测到 {level2_count} 个二级标题,按二级标题分组")
- else:
- top_level = None
- logger.info("无一级或二级标题,按全文处理")
-
- chunks = []
- i = 0
- title_chain = [] # 父级标题链
-
- while i < len(atoms):
- atom = atoms[i]
- if not atom.content.strip():
- i += 1
- continue
-
- # 处理标题
- if atom.atom_type == 'title':
- # 顶级标题:重置标题链
- if top_level and atom.level == top_level:
- title_chain = []
-
- # 更新标题链:保留比当前标题级别低的父级标题
- parent_chain = [t for t in title_chain if t.level < atom.level]
- current_title_chain = parent_chain + [atom]
-
- # 收集该标题后的内容(直到下一个标题)
- i += 1
- content_atoms = []
- while i < len(atoms) and atoms[i].atom_type != 'title':
- content_atoms.append(atoms[i])
- i += 1
-
- # 组装切片:父级标题链 + 当前标题 + 内容
- full_block = current_title_chain + content_atoms
- block_len = sum(len(a) for a in full_block)
-
- if block_len <= self.max_chunk_size and full_block:
- chunks.append(''.join(a.content for a in full_block))
- elif full_block:
- logger.info(f"段落超长 ({block_len} 字符),智能切分")
- sub_chunks = await self.smart_chunk_atoms(full_block)
- chunks.extend(sub_chunks)
-
- # 更新标题链为当前完整链(用于下一个标题的父级链)
- title_chain = current_title_chain
- else:
- # 非标题内容(没有标题的情况)
- content_atoms = []
- while i < len(atoms) and atoms[i].atom_type != 'title':
- content_atoms.append(atoms[i])
- i += 1
-
- if content_atoms:
- full_block = title_chain + content_atoms
- block_len = sum(len(a) for a in full_block)
-
- if block_len <= self.max_chunk_size and full_block:
- chunks.append(''.join(a.content for a in full_block))
- elif full_block:
- logger.info(f"段落超长 ({block_len} 字符),智能切分")
- sub_chunks = await self.smart_chunk_atoms(full_block)
- chunks.extend(sub_chunks)
-
- logger.info(f"最小段落切分: {len(chunks)} 个切片")
- return chunks
-
- async def split_by_max_paragraph(self, atoms: List[MarkdownAtom]) -> List[str]:
- """
- 最大段落模式:自适应选择切分级别
- 优先按一级标题切分,若无一级标题则按二级标题切分
- 不超长直接保存,超长则智能切分(带重叠)
- """
- # 统计标题级别
- level1_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 1)
- level2_count = sum(1 for a in atoms if a.atom_type == 'title' and a.level == 2)
-
- # 动态选择切分级别
- if level1_count > 0:
- split_level = 1
- logger.info(f"检测到 {level1_count} 个一级标题,按一级标题切分")
- elif level2_count > 0:
- split_level = 2
- logger.info(f"无一级标题,检测到 {level2_count} 个二级标题,按二级标题切分")
- else:
- split_level = None
- logger.info("无一级或二级标题,按全文处理")
-
- chunks = []
- i = 0
-
- while i < len(atoms):
- block_atoms = []
-
- # 收集到下一个目标级别标题之前的所有原子
- while i < len(atoms):
- block_atoms.append(atoms[i])
- i += 1
- # 如果下一个是目标级别标题,停止
- if split_level and i < len(atoms) and atoms[i].atom_type == 'title' and atoms[i].level == split_level:
- break
-
- block_len = sum(len(a) for a in block_atoms)
-
- # 不超长,直接保存(不重叠)
- if block_len <= self.max_chunk_size:
- chunks.append(''.join(a.content for a in block_atoms))
- else:
- # 超长,智能切分(带重叠)
- logger.info(f"章节超长 ({block_len} 字符),智能切分")
- sub_chunks = await self.smart_chunk_atoms(block_atoms)
- chunks.extend(sub_chunks)
-
- logger.info(f"最大段落切分: {len(chunks)} 个切片")
- return chunks
-
- async def split_markdown(self, file_path: str, split_mode: str = "min") -> List[str]:
- """
- 切分 Markdown 文件
-
- 参数:
- file_path: MD 文件路径
- split_mode: 'min' 最小段落(标题到标题), 'max' 最大段落(一级标题到一级标题)
-
- 返回:
- List[str]: 切片列表
- """
- logger.info(f"开始切分: {file_path}, 模式: {split_mode}")
-
- content = await self.read_markdown_file(file_path)
- atoms = await self.extract_atoms(content)
-
- if not atoms:
- logger.warning("未提取到原子")
- return []
-
- if split_mode == "min":
- return await self.split_by_min_paragraph(atoms)
- elif split_mode == "max":
- return await self.split_by_max_paragraph(atoms)
- else:
- raise ValueError(f"不支持的模式: {split_mode},请使用 'min' 或 'max'")
# Usage example
if __name__ == "__main__":
    async def main():
        """Split ``test.md`` in both modes and print a preview of the first chunks."""

        def show_preview(chunks):
            # Print up to three chunks, each truncated to 200 characters.
            for number, text in enumerate(chunks[:3], 1):
                print(f"\n--- 切片 {number} ({len(text)} 字符) ---")
                if len(text) > 200:
                    print(text[:200] + "...")
                else:
                    print(text)

        source_path = "test.md"
        splitter = MarkdownSplitter(max_chunk_size=5000, overlap_ratio=0.1)

        # Minimum-paragraph mode
        min_chunks = await splitter.split_markdown(source_path, "min")
        print(f"最小段落: {len(min_chunks)} 个切片")
        show_preview(min_chunks)

        # Maximum-paragraph mode
        max_chunks = await splitter.split_markdown(source_path, "max")
        print(f"\n最大段落: {len(max_chunks)} 个切片")
        show_preview(max_chunks)

    asyncio.run(main())
|