mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-05 11:37:14 +00:00
[coati] fix inference profanity check (#3299)
This commit is contained in:
parent
5134ad5d1a
commit
62f7156131
@ -10,3 +10,4 @@ uvicorn
|
|||||||
git+https://github.com/huggingface/transformers
|
git+https://github.com/huggingface/transformers
|
||||||
accelerate
|
accelerate
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
|
jieba
|
@ -2,6 +2,7 @@ import re
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Callable, Generator, List, Optional
|
from typing import Any, Callable, Generator, List, Optional
|
||||||
import json
|
import json
|
||||||
|
import jieba
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -130,10 +131,7 @@ class ChatPromptProcessor:
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.context = context
|
self.context = context
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
if len(censored_words) > 0:
|
self.censored_words = set([word.lower() for word in censored_words])
|
||||||
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
|
|
||||||
else:
|
|
||||||
self.censored_pat = None
|
|
||||||
# These will be initialized after the first call of preprocess_prompt()
|
# These will be initialized after the first call of preprocess_prompt()
|
||||||
self.context_len: Optional[int] = None
|
self.context_len: Optional[int] = None
|
||||||
self.dialogue_placeholder_len: Optional[int] = None
|
self.dialogue_placeholder_len: Optional[int] = None
|
||||||
@ -179,9 +177,10 @@ class ChatPromptProcessor:
|
|||||||
return output.strip()
|
return output.strip()
|
||||||
|
|
||||||
def has_censored_words(self, text: str) -> bool:
|
def has_censored_words(self, text: str) -> bool:
|
||||||
if self.censored_pat is None:
|
if len(self.censored_words) == 0:
|
||||||
return False
|
return False
|
||||||
return self.censored_pat.search(text) is not None
|
intersection = set(jieba.cut(text.lower())) & self.censored_words
|
||||||
|
return len(intersection) > 0
|
||||||
|
|
||||||
class LockedIterator:
|
class LockedIterator:
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user