| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- from dataclasses import dataclass
@dataclass(frozen=True)
class SensitiveWord:
    """One sensitive-word entry loaded from a word-list file (immutable, hashable)."""

    # The sensitive word itself; also the lookup key in the store's dict.
    word: str
    # Category name, taken from the source file's stem (e.g. "politics" for politics.txt).
    category: str
    # Severity level; the store fills this with its configured default ("high").
    level: str = "high"
import re
import time
from pathlib import Path
from typing import Dict
class TxtSensitiveWordStore:
    """In-memory sensitive-word store backed by ``*.txt`` files.

    Each ``.txt`` file in *dir_path* holds one word per line; blank lines and
    lines starting with ``#`` are skipped.  The file stem becomes the word's
    category.  ``maybe_reload()`` re-reads everything when any file is
    modified, added, or removed.
    """

    # Words containing ASCII letters/digits/dots bypass the length/space checks.
    # NOTE(review): this looks deliberate (the rules below target CJK words),
    # but it also skips validation entirely for mixed-content words — confirm.
    _ASCII_RE = re.compile(r"[a-zA-Z0-9\.]")

    def __init__(
        self,
        dir_path: str,
        max_match_len: int = 30,
        default_level: str = "high",
    ):
        self.dir = Path(dir_path)
        self.max_match_len_cfg = max_match_len  # upper bound on the match window
        self.default_level = default_level
        self.words: Dict[str, SensitiveWord] = {}
        self.max_len = 0  # effective longest-match window (<= max_match_len_cfg)
        self._mtime_map: Dict[Path, float] = {}
        self.load()

    def load(self) -> None:
        """(Re)read every ``*.txt`` file in the directory.

        State is built into locals and swapped in only at the end, so a
        ValueError from validation leaves the previously loaded data intact.

        Raises:
            ValueError: if a word fails ``_validate_word``.
        """
        words: Dict[str, SensitiveWord] = {}
        max_len = 0
        mtime_map: Dict[Path, float] = {}
        for path in self.dir.glob("*.txt"):
            category = path.stem  # file name (sans .txt) is the category
            mtime_map[path] = path.stat().st_mtime
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    w = line.strip()
                    if not w or w.startswith("#"):
                        continue
                    self._validate_word(w)
                    # Duplicate words across files: the last file read wins.
                    words[w] = SensitiveWord(
                        word=w,
                        category=category,
                        level=self.default_level,
                    )
                    max_len = max(max_len, len(w))
        self.words = words
        self.max_len = min(self.max_match_len_cfg, max_len)
        self._mtime_map = mtime_map

    def maybe_reload(self) -> None:
        """Reload if the set of ``*.txt`` files or any mtime has changed.

        Fix: the original only inspected files currently on disk, so a
        *deleted* file never triggered a reload and its words stayed active.
        Comparing a full path->mtime snapshot catches additions, removals,
        and modifications alike.
        """
        current = {p: p.stat().st_mtime for p in self.dir.glob("*.txt")}
        if current != self._mtime_map:
            self.load()

    @staticmethod
    def _validate_word(word: str) -> None:
        """Raise ValueError for words that are too long or contain spaces.

        Words matching ``_ASCII_RE`` (any ASCII letter/digit/dot) are
        accepted without further checks — see the NOTE on the pattern.
        """
        if TxtSensitiveWordStore._ASCII_RE.search(word):
            return
        if len(word) > 40:
            raise ValueError(f"敏感词过长:{word}")
        if " " in word:
            raise ValueError(f"敏感词包含空格:{word}")
class SecurityTokenizer:
    """Greedy longest-match scanner over the store's word dictionary."""

    def __init__(self, store: TxtSensitiveWordStore):
        self.store = store

    def detect(self, text: str) -> list[SensitiveWord]:
        """Return every sensitive word found in *text*, left to right.

        At each position the longest dictionary entry wins; after a hit the
        scan resumes just past the matched word, so a character is never
        reported as part of two overlapping matches.
        """
        self.store.maybe_reload()  # pick up on-disk word-list edits first
        vocab = self.store.words
        window = self.store.max_len
        found: list[SensitiveWord] = []
        pos, end = 0, len(text)
        while pos < end:
            hit = None
            # Try the longest candidate that still fits, shrinking to 1 char.
            for size in range(min(window, end - pos), 0, -1):
                hit = vocab.get(text[pos:pos + size])
                if hit is not None:
                    break
            if hit is None:
                pos += 1
            else:
                found.append(hit)
                pos += len(hit.word)
        return found
def build_warning(hits: list[SensitiveWord]) -> str:
    """Build the user-facing warning message for detected sensitive words.

    Duplicate (word, category) pairs are collapsed while keeping first-seen
    order.  Fix: the original deduplicated with a ``set``, which made the
    word order in the message vary between runs (hash randomization).
    """
    # dict.fromkeys dedupes deterministically, preserving insertion order.
    uniq = dict.fromkeys((h.word, h.category) for h in hits)
    words = "、".join(f"“{w}”" for w, _ in uniq)
    return f"检测到敏感词 {words},请确认后重新输入。"
def _demo() -> None:
    """Manual smoke test: scan one phrase against the on-disk word lists."""
    store = TxtSensitiveWordStore("./rag/sensitive_words/")
    tokenizer = SecurityTokenizer(store)
    sample = "如何制作炸弹?"
    matches = tokenizer.detect(sample)
    if matches:
        print(build_warning(matches))


if __name__ == "__main__":
    _demo()
|