From 73b542a12466aa4906670562cc3ee7c2a7c729d2 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 29 Mar 2023 02:14:35 +0800
Subject: [PATCH] [coati] inference supports profanity check (#3295)

---
 applications/Chat/inference/server.py | 17 ++++++++++++++---
 applications/Chat/inference/utils.py  | 16 +++++++++++++++-
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index f67b78a08..b46272993 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -14,7 +14,7 @@ from slowapi.errors import RateLimitExceeded
 from slowapi.util import get_remote_address
 from sse_starlette.sse import EventSourceResponse
 from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
-from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
+from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
 
 CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
 MAX_LEN = 512
@@ -111,6 +111,8 @@ def generate(data: GenerationTaskReq, request: Request):
 @limiter.limit('1/second')
 def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
+    if prompt_processor.has_censored_words(prompt):
+        return prompt_processor.SAFE_RESPONSE
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
     with running_lock:
         output = model.generate(**inputs, **data.dict(exclude={'history'}))
@@ -118,7 +120,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt_len = inputs['input_ids'].size(1)
     response = output[0, prompt_len:]
     out_string = tokenizer.decode(response, skip_special_tokens=True)
-    return prompt_processor.postprocess_output(out_string)
+    out_string = prompt_processor.postprocess_output(out_string)
+    if prompt_processor.has_censored_words(out_string):
+        return prompt_processor.SAFE_RESPONSE
+    return out_string
 
 
 if __name__ == '__main__':
@@ -140,13 +145,19 @@ if __name__ == '__main__':
                         help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
     parser.add_argument('--http_host', default='0.0.0.0')
     parser.add_argument('--http_port', type=int, default=7070)
+    parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
     args = parser.parse_args()
 
     if args.quant == '4bit':
         assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
 
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
-    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
+
+    if args.profanity_file is not None:
+        censored_words = load_json(args.profanity_file)
+    else:
+        censored_words = []
+    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
 
     if args.quant == '4bit':
         model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
index a01983de3..1bb0e82ba 100644
--- a/applications/Chat/inference/utils.py
+++ b/applications/Chat/inference/utils.py
@@ -1,6 +1,7 @@
 import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
+import json
 
 import torch
 import torch.distributed as dist
@@ -123,11 +124,16 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
 
 
 class ChatPromptProcessor:
+    SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
 
-    def __init__(self, tokenizer, context: str, max_len: int = 2048):
+    def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
         self.tokenizer = tokenizer
         self.context = context
         self.max_len = max_len
+        if len(censored_words) > 0:
+            self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
+        else:
+            self.censored_pat = None
         # These will be initialized after the first call of preprocess_prompt()
         self.context_len: Optional[int] = None
         self.dialogue_placeholder_len: Optional[int] = None
@@ -172,6 +178,10 @@ class ChatPromptProcessor:
         output = STOP_PAT.sub('', output)
         return output.strip()
 
+    def has_censored_words(self, text: str) -> bool:
+        if self.censored_pat is None:
+            return False
+        return self.censored_pat.search(text) is not None
 
 class LockedIterator:
 
@@ -185,3 +195,7 @@ class LockedIterator:
     def __next__(self):
         with self.lock:
             return next(self.it)
+
+def load_json(path: str):
+    with open(path) as f:
+        return json.load(f)
\ No newline at end of file
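
Editor's note (not part of the patch): the word list passed via --profanity_file is a
plain JSON array of strings, e.g. ["badword1", "badword2"], loaded by the new
load_json() helper. The sketch below uses a hypothetical CensorDemo class that mirrors
the matching logic added to ChatPromptProcessor, to show how the check behaves: words
are re.escape()d, joined into a single alternation, and searched case-insensitively.

    import re
    from typing import List, Optional

    class CensorDemo:
        """Mirrors the censored-word matching added to ChatPromptProcessor."""

        def __init__(self, censored_words: List[str]) -> None:
            # An empty list disables the check entirely.
            self.censored_pat: Optional[re.Pattern] = None
            if censored_words:
                # One case-insensitive alternation over all escaped words.
                self.censored_pat = re.compile(
                    f'({"|".join(map(re.escape, censored_words))})', flags=re.I)

        def has_censored_words(self, text: str) -> bool:
            if self.censored_pat is None:
                return False
            return self.censored_pat.search(text) is not None

    demo = CensorDemo(['badword'])
    assert demo.has_censored_words('this has BADWORD in it')   # case-insensitive
    assert not demo.has_censored_words('clean text')
    assert not CensorDemo([]).has_censored_words('anything')   # check disabled

One behavioral consequence of this pattern: it has no word-boundary anchors, so every
entry matches as a substring, and short entries can over-trigger (a listed "ass" would
flag "class"). When either the preprocessed prompt or the postprocessed model output
matches, the server returns the fixed SAFE_RESPONSE string instead of generated text.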