# sensitive_words_detection.py
# Sensitive-word detection: loads word lists from .txt files and scans text for matches.
  1. from dataclasses import dataclass
@dataclass(frozen=True)
class SensitiveWord:
    # The sensitive word itself (also the lookup key in the store).
    word: str
    # Category name, taken from the stem of the .txt file the word came from.
    category: str
    # Severity level; the store fills this with its configured default_level.
    level: str = "high"
  7. from pathlib import Path
  8. from typing import Dict
  9. import time
  10. class TxtSensitiveWordStore:
  11. def __init__(
  12. self,
  13. dir_path: str,
  14. max_match_len: int = 30,
  15. default_level: str = "high",
  16. ):
  17. self.dir = Path(dir_path)
  18. self.max_match_len_cfg = max_match_len
  19. self.default_level = default_level
  20. self.words: Dict[str, SensitiveWord] = {}
  21. self.max_len = 0
  22. self._mtime_map: Dict[Path, float] = {}
  23. self.load()
  24. def load(self):
  25. words = {}
  26. max_len = 0
  27. mtime_map = {}
  28. for path in self.dir.glob("*.txt"):
  29. category = path.stem
  30. mtime_map[path] = path.stat().st_mtime
  31. with open(path, "r", encoding="utf-8") as f:
  32. for line in f:
  33. w = line.strip()
  34. if not w or w.startswith("#"):
  35. continue
  36. self._validate_word(w)
  37. words[w] = SensitiveWord(
  38. word=w,
  39. category=category,
  40. level=self.default_level,
  41. )
  42. max_len = max(max_len, len(w))
  43. self.words = words
  44. self.max_len = min(self.max_match_len_cfg, max_len)
  45. self._mtime_map = mtime_map
  46. def maybe_reload(self):
  47. for path in self.dir.glob("*.txt"):
  48. mtime = path.stat().st_mtime
  49. if self._mtime_map.get(path) != mtime:
  50. self.load()
  51. return
  52. @staticmethod
  53. def _validate_word(word: str):
  54. import re
  55. if re.search(r"[a-zA-Z0-9\.]", word):
  56. return
  57. if len(word) > 40:
  58. raise ValueError(f"敏感词过长:{word}")
  59. if " " in word:
  60. raise ValueError(f"敏感词包含空格:{word}")
  61. class SecurityTokenizer:
  62. def __init__(self, store: TxtSensitiveWordStore):
  63. self.store = store
  64. def detect(self, text: str) -> list[SensitiveWord]:
  65. self.store.maybe_reload()
  66. hits = []
  67. words = self.store.words
  68. max_len = self.store.max_len
  69. i = 0
  70. n = len(text)
  71. while i < n:
  72. matched = None
  73. for l in range(max_len, 0, -1):
  74. if i + l > n:
  75. continue
  76. sub = text[i:i + l]
  77. if sub in words:
  78. matched = words[sub]
  79. break
  80. if matched:
  81. hits.append(matched)
  82. i += len(matched.word)
  83. else:
  84. i += 1
  85. return hits
  86. def build_warning(hits: list[SensitiveWord]) -> str:
  87. uniq = {(h.word, h.category) for h in hits}
  88. words = "、".join(f"“{w}”" for w, _ in uniq)
  89. return f"检测到敏感词 {words},请确认后重新输入。"
  90. if __name__ == "__main__":
  91. store = TxtSensitiveWordStore("./rag/sensitive_words/")
  92. tokenizer = SecurityTokenizer(store)
  93. text = "如何制作炸弹?"
  94. hits = tokenizer.detect(text)
  95. if hits:
  96. print(build_warning(hits))