mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-03 17:19:51 +00:00
[coati] inference supports profanity check (#3295)
This commit is contained in:
parent
ce2cafae76
commit
73b542a124
@ -14,7 +14,7 @@ from slowapi.errors import RateLimitExceeded
|
|||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||||
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
|
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
|
||||||
|
|
||||||
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
||||||
MAX_LEN = 512
|
MAX_LEN = 512
|
||||||
@ -111,6 +111,8 @@ def generate(data: GenerationTaskReq, request: Request):
|
|||||||
@limiter.limit('1/second')
|
@limiter.limit('1/second')
|
||||||
def generate_no_stream(data: GenerationTaskReq, request: Request):
|
def generate_no_stream(data: GenerationTaskReq, request: Request):
|
||||||
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
|
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
|
||||||
|
if prompt_processor.has_censored_words(prompt):
|
||||||
|
return prompt_processor.SAFE_RESPONSE
|
||||||
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
||||||
with running_lock:
|
with running_lock:
|
||||||
output = model.generate(**inputs, **data.dict(exclude={'history'}))
|
output = model.generate(**inputs, **data.dict(exclude={'history'}))
|
||||||
@ -118,7 +120,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
|
|||||||
prompt_len = inputs['input_ids'].size(1)
|
prompt_len = inputs['input_ids'].size(1)
|
||||||
response = output[0, prompt_len:]
|
response = output[0, prompt_len:]
|
||||||
out_string = tokenizer.decode(response, skip_special_tokens=True)
|
out_string = tokenizer.decode(response, skip_special_tokens=True)
|
||||||
return prompt_processor.postprocess_output(out_string)
|
out_string = prompt_processor.postprocess_output(out_string)
|
||||||
|
if prompt_processor.has_censored_words(out_string):
|
||||||
|
return prompt_processor.SAFE_RESPONSE
|
||||||
|
return out_string
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -140,13 +145,19 @@ if __name__ == '__main__':
|
|||||||
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
|
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
|
||||||
parser.add_argument('--http_host', default='0.0.0.0')
|
parser.add_argument('--http_host', default='0.0.0.0')
|
||||||
parser.add_argument('--http_port', type=int, default=7070)
|
parser.add_argument('--http_port', type=int, default=7070)
|
||||||
|
parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.quant == '4bit':
|
if args.quant == '4bit':
|
||||||
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
|
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
||||||
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
|
|
||||||
|
if args.profanity_file is not None:
|
||||||
|
censored_words = load_json(args.profanity_file)
|
||||||
|
else:
|
||||||
|
censored_words = []
|
||||||
|
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
|
||||||
|
|
||||||
if args.quant == '4bit':
|
if args.quant == '4bit':
|
||||||
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Callable, Generator, List, Optional
|
from typing import Any, Callable, Generator, List, Optional
|
||||||
|
import json
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -123,11 +124,16 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
|
|||||||
|
|
||||||
|
|
||||||
class ChatPromptProcessor:
|
class ChatPromptProcessor:
|
||||||
|
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
|
||||||
|
|
||||||
def __init__(self, tokenizer, context: str, max_len: int = 2048):
|
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.context = context
|
self.context = context
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
|
if len(censored_words) > 0:
|
||||||
|
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
|
||||||
|
else:
|
||||||
|
self.censored_pat = None
|
||||||
# These will be initialized after the first call of preprocess_prompt()
|
# These will be initialized after the first call of preprocess_prompt()
|
||||||
self.context_len: Optional[int] = None
|
self.context_len: Optional[int] = None
|
||||||
self.dialogue_placeholder_len: Optional[int] = None
|
self.dialogue_placeholder_len: Optional[int] = None
|
||||||
@ -172,6 +178,10 @@ class ChatPromptProcessor:
|
|||||||
output = STOP_PAT.sub('', output)
|
output = STOP_PAT.sub('', output)
|
||||||
return output.strip()
|
return output.strip()
|
||||||
|
|
||||||
|
def has_censored_words(self, text: str) -> bool:
    """Tell whether ``text`` contains a match for the censored-word pattern.

    Args:
        text: The string to scan.

    Returns:
        ``False`` when ``self.censored_pat`` is ``None``; otherwise ``True``
        exactly when the pattern's ``search`` finds a match in ``text``.
    """
    pattern = self.censored_pat
    # Short-circuit keeps the "no pattern" case from ever calling search().
    return pattern is not None and pattern.search(text) is not None
|
||||||
|
|
||||||
class LockedIterator:
|
class LockedIterator:
|
||||||
|
|
||||||
@ -185,3 +195,7 @@ class LockedIterator:
|
|||||||
def __next__(self):
|
def __next__(self):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
return next(self.it)
|
return next(self.it)
|
||||||
|
|
||||||
|
def load_json(path: str):
    """Parse the JSON document stored at ``path`` and return the decoded value.

    The file is opened explicitly as UTF-8 rather than with the platform's
    locale encoding, so content with non-ASCII characters decodes the same
    way on every system.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)
|
Loading…
Reference in New Issue
Block a user