[coati] inference supports profanity check (#3295)

This commit is contained in:
ver217 2023-03-29 02:14:35 +08:00 committed by GitHub
parent ce2cafae76
commit 73b542a124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 4 deletions

View File

@ -14,7 +14,7 @@ from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512 MAX_LEN = 512
@ -111,6 +111,8 @@ def generate(data: GenerationTaskReq, request: Request):
@limiter.limit('1/second') @limiter.limit('1/second')
def generate_no_stream(data: GenerationTaskReq, request: Request): def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock: with running_lock:
output = model.generate(**inputs, **data.dict(exclude={'history'})) output = model.generate(**inputs, **data.dict(exclude={'history'}))
@ -118,7 +120,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt_len = inputs['input_ids'].size(1) prompt_len = inputs['input_ids'].size(1)
response = output[0, prompt_len:] response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True) out_string = tokenizer.decode(response, skip_special_tokens=True)
return prompt_processor.postprocess_output(out_string) out_string = prompt_processor.postprocess_output(out_string)
if prompt_processor.has_censored_words(out_string):
return prompt_processor.SAFE_RESPONSE
return out_string
if __name__ == '__main__': if __name__ == '__main__':
@ -140,13 +145,19 @@ if __name__ == '__main__':
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
parser.add_argument('--http_host', default='0.0.0.0') parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070) parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
args = parser.parse_args() args = parser.parse_args()
if args.quant == '4bit': if args.quant == '4bit':
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
tokenizer = AutoTokenizer.from_pretrained(args.pretrained) tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
if args.profanity_file is not None:
censored_words = load_json(args.profanity_file)
else:
censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
if args.quant == '4bit': if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)

View File

@ -1,6 +1,7 @@
import re import re
from threading import Lock from threading import Lock
from typing import Any, Callable, Generator, List, Optional from typing import Any, Callable, Generator, List, Optional
import json
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -123,11 +124,16 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
class ChatPromptProcessor: class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
def __init__(self, tokenizer, context: str, max_len: int = 2048): def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.context = context self.context = context
self.max_len = max_len self.max_len = max_len
if len(censored_words) > 0:
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
else:
self.censored_pat = None
# These will be initialized after the first call of preprocess_prompt() # These will be initialized after the first call of preprocess_prompt()
self.context_len: Optional[int] = None self.context_len: Optional[int] = None
self.dialogue_placeholder_len: Optional[int] = None self.dialogue_placeholder_len: Optional[int] = None
@ -172,6 +178,10 @@ class ChatPromptProcessor:
output = STOP_PAT.sub('', output) output = STOP_PAT.sub('', output)
return output.strip() return output.strip()
def has_censored_words(self, text: str) -> bool:
if self.censored_pat is None:
return False
return self.censored_pat.search(text) is not None
class LockedIterator: class LockedIterator:
@ -185,3 +195,7 @@ class LockedIterator:
def __next__(self): def __next__(self):
with self.lock: with self.lock:
return next(self.it) return next(self.it)
def load_json(path: str):
with open(path) as f:
return json.load(f)