Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 13:30:19 +00:00)

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
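The pre-commit configuration itself is not part of this excerpt, but the new layout in the hunks below (double quotes, one argument per line, trailing commas) is the style black produces. A minimal sketch, assuming black is the formatter wired into the updated hooks:

    # Assumption: black is the formatter behind the updated hooks; the actual
    # .pre-commit-config.yaml is not shown in this diff.
    import black

    old = "parser.add_argument('--quant', choices=['8bit', '4bit'], default=None)\n"
    print(black.format_str(old, mode=black.Mode()))
    # -> parser.add_argument("--quant", choices=["8bit", "4bit"], default=None)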
@@ -84,28 +84,34 @@ inst = [instructions[0]] * 4
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        'pretrained',
-        help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
-    parser.add_argument('--quant',
-                        choices=['8bit', '4bit'],
-                        default=None,
-                        help='Quantization mode. Default: None (no quantization, fp16).')
+        "pretrained",
+        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+    )
+    parser.add_argument(
+        "--quant",
+        choices=["8bit", "4bit"],
+        default=None,
+        help="Quantization mode. Default: None (no quantization, fp16).",
+    )
     parser.add_argument(
-        '--gptq_checkpoint',
-        default=None,
-        help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
-    parser.add_argument('--gptq_group_size',
-                        type=int,
-                        default=128,
-                        help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
+        "--gptq_checkpoint",
+        default=None,
+        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+    )
+    parser.add_argument(
+        "--gptq_group_size",
+        type=int,
+        default=128,
+        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+    )
     args = parser.parse_args()
 
-    if args.quant == '4bit':
-        assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+    if args.quant == "4bit":
+        assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
 
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
 
-    if args.quant == '4bit':
+    if args.quant == "4bit":
         with low_resource_init():
             config = LlamaConfig.from_pretrained(args.pretrained)
             model = LlamaForCausalLM(config)
@@ -114,12 +120,12 @@ if __name__ == "__main__":
     else:
         model = LlamaForCausalLM.from_pretrained(
             args.pretrained,
-            load_in_8bit=(args.quant == '8bit'),
+            load_in_8bit=(args.quant == "8bit"),
             torch_dtype=torch.float16,
             device_map="auto",
         )
-        if args.quant != '8bit':
-            model.half()  # seems to fix bugs for some users.
+        if args.quant != "8bit":
+            model.half()  # seems to fix bugs for some users.
         model.eval()
 
     total_tokens = 0
@@ -129,7 +135,7 @@ if __name__ == "__main__":
         resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
         total_tokens += tokens
         print(f"Response: {resp}")
-        print('\n----------------------------\n')
+        print("\n----------------------------\n")
     duration = time() - start
-    print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s')
-    print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
+    print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s")
+    print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
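For reference, a minimal sketch (not part of the diff) of how the benchmark CLI reformatted above is driven; the model path and checkpoint name are placeholders:

    # Hypothetical invocation of the argument parser defined in the hunk above;
    # "/path/to/llama" and "llama-4bit.pt" are placeholder values.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("pretrained")
    parser.add_argument("--quant", choices=["8bit", "4bit"], default=None)
    parser.add_argument("--gptq_checkpoint", default=None)
    parser.add_argument("--gptq_group_size", type=int, default=128)

    args = parser.parse_args(["/path/to/llama", "--quant", "4bit", "--gptq_checkpoint", "llama-4bit.pt"])
    print(args.pretrained, args.quant, args.gptq_checkpoint, args.gptq_group_size)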
@@ -1,26 +1,26 @@
 from json import JSONDecodeError
 
 from locust import HttpUser, task
 
-samples = [[
-    dict(
-        instruction='Who is the best player in the history of NBA?',
-        response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
-    ),
-    dict(instruction='continue this talk', response=''),
-], [
-    dict(instruction='Who is the best player in the history of NBA?', response=''),
-]]
+samples = [
+    [
+        dict(
+            instruction="Who is the best player in the history of NBA?",
+            response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+        ),
+        dict(instruction="continue this talk", response=""),
+    ],
+    [
+        dict(instruction="Who is the best player in the history of NBA?", response=""),
+    ],
+]
 
 
 class GenerationUser(HttpUser):
-
     @task
     def generate(self):
         for sample in samples:
-            data = {'max_new_tokens': 64, 'history': sample}
-            with self.client.post('/generate', json=data, catch_response=True) as response:
+            data = {"max_new_tokens": 64, "history": sample}
+            with self.client.post("/generate", json=data, catch_response=True) as response:
                 if response.status_code in (200, 406):
                     response.success()
                 else:
-                    response.failure('Response wrong')
+                    response.failure("Response wrong")
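The locust task above posts a 64-token request with a dialogue history to /generate and accepts 200 or 406. A one-off sketch of the same request with requests; http://localhost:7070 is an assumption based on the server defaults that appear later in this diff:

    # Single request mirroring the load-test payload; the host and port are assumptions.
    import requests

    history = [{"instruction": "Who is the best player in the history of NBA?", "response": ""}]
    resp = requests.post(
        "http://localhost:7070/generate",
        json={"max_new_tokens": 64, "history": history},
        timeout=120,
    )
    print(resp.status_code, resp.text)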
@@ -16,7 +16,7 @@ from sse_starlette.sse import EventSourceResponse
 from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
 from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
 
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
 MAX_LEN = 512
 running_lock = Lock()
 
@@ -36,11 +36,11 @@ app.state.limiter = limiter
 app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
 
 # set CORS
-origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
+origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
 
 if origin_spec_from_env is not None:
     # allow CORS from the specified origins
-    origins = os.environ['CORS_ORIGIN'].split(',')
+    origins = os.environ["CORS_ORIGIN"].split(",")
 else:
     # allow CORS from all origins
     origins = ["*"]
@@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
     # TODO(ver217): streaming generation does not support repetition_penalty now
     model_kwargs = {
-        'max_generate_tokens': max_new_tokens,
-        'early_stopping': True,
-        'top_k': top_k,
-        'top_p': top_p,
-        'temperature': temperature,
-        'prepare_inputs_fn': model.prepare_inputs_for_generation,
-        'update_model_kwargs_fn': update_model_kwargs_fn,
+        "max_generate_tokens": max_new_tokens,
+        "early_stopping": True,
+        "top_k": top_k,
+        "top_p": top_p,
+        "temperature": temperature,
+        "prepare_inputs_fn": model.prepare_inputs_for_generation,
+        "update_model_kwargs_fn": update_model_kwargs_fn,
     }
     is_first_word = True
     generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
             if is_first_word:
                 out_string = out_string.lstrip()
                 is_first_word = False
-            elif current_sub_tokens[0].startswith('▁'):
+            elif current_sub_tokens[0].startswith("▁"):
                 # whitespace will be ignored by the frontend
-                out_string = ' ' + out_string
+                out_string = " " + out_string
             yield out_string
 
 
@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator):
         if await request.is_disconnected():
             break
         try:
-            yield {'event': 'generate', 'data': next(generator)}
+            yield {"event": "generate", "data": next(generator)}
         except StopIteration:
-            yield {'event': 'end', 'data': ''}
+            yield {"event": "end", "data": ""}
             break
 
 
-@app.post('/generate/stream')
-@limiter.limit('1/second')
+@app.post("/generate/stream")
+@limiter.limit("1/second")
 def generate(data: GenerationTaskReq, request: Request):
     prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
     event_source = event_generator(
-        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
+        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
+    )
     return EventSourceResponse(event_source)
 
 
-@app.post('/generate')
-@limiter.limit('1/second')
+@app.post("/generate")
+@limiter.limit("1/second")
 def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
     if prompt_processor.has_censored_words(prompt):
         return prompt_processor.SAFE_RESPONSE
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
     with running_lock:
-        output = model.generate(**inputs, **data.dict(exclude={'history'}))
+        output = model.generate(**inputs, **data.dict(exclude={"history"}))
     output = output.cpu()
-    prompt_len = inputs['input_ids'].size(1)
+    prompt_len = inputs["input_ids"].size(1)
     response = output[0, prompt_len:]
     out_string = tokenizer.decode(response, skip_special_tokens=True)
     out_string = prompt_processor.postprocess_output(out_string)
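The /generate/stream handler above wraps the token generator in event_generator, which emits "generate" events until a final "end" event. A rough client-side sketch; the host, port, and sampling values are assumptions, while the field names come from the handler itself:

    # Minimal SSE consumer for /generate/stream using plain line iteration.
    import requests

    payload = {
        "max_new_tokens": 64,
        "top_k": 50,
        "top_p": 0.5,
        "temperature": 0.7,
        "history": [{"instruction": "Who is the best player in the history of NBA?", "response": ""}],
    }
    event = None
    with requests.post("http://localhost:7070/generate/stream", json=payload, stream=True, timeout=120) as r:
        for raw in r.iter_lines(decode_unicode=True):
            if not raw:
                continue
            if raw.startswith("event:"):
                event = raw.split(":", 1)[1].strip()
            elif raw.startswith("data:") and event == "generate":
                print(raw.split(":", 1)[1].strip(), flush=True)
            elif event == "end":
                break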
@@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
     return out_string
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        'pretrained',
-        help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
-    parser.add_argument('--quant',
-                        choices=['8bit', '4bit'],
-                        default=None,
-                        help='Quantization mode. Default: None (no quantization, fp16).')
-    parser.add_argument(
-        '--gptq_checkpoint',
-        default=None,
-        help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
-    parser.add_argument('--gptq_group_size',
-                        type=int,
-                        default=128,
-                        help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
-    parser.add_argument('--http_host', default='0.0.0.0')
-    parser.add_argument('--http_port', type=int, default=7070)
-    parser.add_argument('--profanity_file',
-                        default=None,
-                        help='Path to profanity words list. It should be a JSON file containing a list of words.')
+        "pretrained",
+        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+    )
+    parser.add_argument(
+        "--quant",
+        choices=["8bit", "4bit"],
+        default=None,
+        help="Quantization mode. Default: None (no quantization, fp16).",
+    )
+    parser.add_argument(
+        "--gptq_checkpoint",
+        default=None,
+        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+    )
+    parser.add_argument(
+        "--gptq_group_size",
+        type=int,
+        default=128,
+        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+    )
+    parser.add_argument("--http_host", default="0.0.0.0")
+    parser.add_argument("--http_port", type=int, default=7070)
+    parser.add_argument(
+        "--profanity_file",
+        default=None,
+        help="Path to profanity words list. It should be a JSON file containing a list of words.",
+    )
     args = parser.parse_args()
 
-    if args.quant == '4bit':
-        assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+    if args.quant == "4bit":
+        assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
 
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
 
@@ -161,7 +170,7 @@ if __name__ == '__main__':
         censored_words = []
     prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
 
-    if args.quant == '4bit':
+    if args.quant == "4bit":
         with low_resource_init():
             config = LlamaConfig.from_pretrained(args.pretrained)
             model = LlamaForCausalLM(config)
@@ -170,12 +179,12 @@ if __name__ == '__main__':
     else:
         model = LlamaForCausalLM.from_pretrained(
             args.pretrained,
-            load_in_8bit=(args.quant == '8bit'),
+            load_in_8bit=(args.quant == "8bit"),
             torch_dtype=torch.float16,
            device_map="auto",
         )
-        if args.quant != '8bit':
-            model.half()  # seems to fix bugs for some users.
+        if args.quant != "8bit":
+            model.half()  # seems to fix bugs for some users.
         model.eval()
 
     config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
@@ -3,41 +3,49 @@ import os
 from transformers import AutoTokenizer
 from utils import ChatPromptProcessor, Dialogue
 
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
-tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH'])
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
+tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
 
 samples = [
-    ([
-        Dialogue(
-            instruction='Who is the best player in the history of NBA?',
-            response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
-        ),
-        Dialogue(instruction='continue this talk', response=''),
-    ], 128,
-     'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+    (
+        [
+            Dialogue(
+                instruction="Who is the best player in the history of NBA?",
+                response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+            ),
+            Dialogue(instruction="continue this talk", response=""),
+        ],
+        128,
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
     ),
-    ([
-        Dialogue(
-            instruction='Who is the best player in the history of NBA?',
-            response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
-        ),
-        Dialogue(instruction='continue this talk', response=''),
-    ], 200,
-     'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+    (
+        [
+            Dialogue(
+                instruction="Who is the best player in the history of NBA?",
+                response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+            ),
+            Dialogue(instruction="continue this talk", response=""),
+        ],
+        200,
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
     ),
-    ([
-        Dialogue(
-            instruction='Who is the best player in the history of NBA?',
-            response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
-        ),
-        Dialogue(instruction='continue this talk', response=''),
-    ], 211,
-     'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
+    (
+        [
+            Dialogue(
+                instruction="Who is the best player in the history of NBA?",
+                response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+            ),
+            Dialogue(instruction="continue this talk", response=""),
+        ],
+        211,
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
     ),
-    ([
-        Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
-    ], 128,
-     'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
+    (
+        [
+            Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
+        ],
+        128,
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
     ),
 ]
 
@@ -49,5 +57,5 @@ def test_chat_prompt_processor():
         assert prompt == result
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_chat_prompt_processor()
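The expected strings in the test above are the CONTEXT header plus one instruction/response block per retained dialogue. A small hand-rolled sketch of that composition (format_dialogue here mirrors the _format_dialogue helper that appears later in this diff):

    # Hand-rolled composition of the expected prompt strings used in the test above.
    CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."

    def format_dialogue(instruction: str, response: str = "") -> str:
        # same shape as the repository's _format_dialogue helper shown further down
        return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"

    expected = CONTEXT + format_dialogue("Who is the best player in the history of NBA?")
    print(expected)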
@@ -20,9 +20,9 @@ except ImportError:
     from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
 
 
-def prepare_logits_processor(top_k: Optional[int] = None,
-                             top_p: Optional[float] = None,
-                             temperature: Optional[float] = None) -> LogitsProcessorList:
+def prepare_logits_processor(
+    top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
     processor_list = LogitsProcessorList()
     if temperature is not None and temperature != 1.0:
         processor_list.append(TemperatureLogitsWarper(temperature))
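prepare_logits_processor only changes layout here; for context, it assembles a LogitsProcessorList of warpers that sample_streamingly applies to the raw next-token logits. A standalone sketch with arbitrary values:

    # Standalone illustration of the warper pipeline prepare_logits_processor builds;
    # the temperature/top-k/top-p values and vocabulary size are arbitrary.
    import torch
    from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

    processors = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(50), TopPLogitsWarper(0.9)])
    input_ids = torch.zeros((1, 4), dtype=torch.long)  # dummy prompt tokens
    logits = torch.randn(1, 32000)                     # dummy next-token logits
    warped = processors(input_ids, logits)             # same call pattern as logits_processor(input_ids, next_token_logits)
    print(warped.shape)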
@@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
     return unfinished_sequences.max() == 0
 
 
-def sample_streamingly(model: nn.Module,
-                       input_ids: torch.Tensor,
-                       max_generate_tokens: int,
-                       early_stopping: bool = False,
-                       eos_token_id: Optional[int] = None,
-                       pad_token_id: Optional[int] = None,
-                       top_k: Optional[int] = None,
-                       top_p: Optional[float] = None,
-                       temperature: Optional[float] = None,
-                       prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
-                       update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
-                       **model_kwargs) -> Generator:
+def sample_streamingly(
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    max_generate_tokens: int,
+    early_stopping: bool = False,
+    eos_token_id: Optional[int] = None,
+    pad_token_id: Optional[int] = None,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+    temperature: Optional[float] = None,
+    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+    **model_kwargs,
+) -> Generator:
     logits_processor = prepare_logits_processor(top_k, top_p, temperature)
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
 
     for _ in range(max_generate_tokens):
-        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
-            'input_ids': input_ids
-        }
+        model_inputs = (
+            prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+        )
         outputs = model(**model_inputs)
 
-        next_token_logits = outputs['logits'][:, -1, :]
+        next_token_logits = outputs["logits"][:, -1, :]
         # pre-process distribution
         next_token_logits = logits_processor(input_ids, next_token_logits)
         # sample
@@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
     if "attention_mask" in model_kwargs:
         attention_mask = model_kwargs["attention_mask"]
         model_kwargs["attention_mask"] = torch.cat(
-            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
+            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+        )
 
     return model_kwargs
 
 
 class Dialogue(BaseModel):
-    instruction: str = Field(min_length=1, example='Count up from 1 to 500.')
-    response: str = Field(example='')
+    instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
+    response: str = Field(example="")
 
 
-def _format_dialogue(instruction: str, response: str = ''):
-    return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
+def _format_dialogue(instruction: str, response: str = ""):
+    return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
 
 
-STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
+STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
 
 
 class ChatPromptProcessor:
-    SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
+    SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
 
     def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
         self.tokenizer = tokenizer
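The torch.cat call reformatted above is the per-step bookkeeping in update_model_kwargs_fn: each generated token appends one column of ones to the attention mask. A toy-tensor sketch:

    # Toy illustration of the attention-mask growth performed by update_model_kwargs_fn.
    import torch

    attention_mask = torch.ones(2, 5, dtype=torch.long)  # batch of 2, prompt length 5
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
    print(attention_mask.shape)  # torch.Size([2, 6])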
@@ -138,42 +140,48 @@ class ChatPromptProcessor:
 
     def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
         if self.context_len is None:
-            self.context_len = len(self.tokenizer(self.context)['input_ids'])
+            self.context_len = len(self.tokenizer(self.context)["input_ids"])
         if self.dialogue_placeholder_len is None:
             self.dialogue_placeholder_len = len(
-                self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])
+                self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
+            )
         prompt = self.context
         # the last dialogue must be in the prompt
         last_dialogue = history.pop()
         # the response of the last dialogue is empty
-        assert last_dialogue.response == ''
-        if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)
-               ['input_ids']) + max_new_tokens + self.context_len >= self.max_len:
+        assert last_dialogue.response == ""
+        if (
+            len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
+            + max_new_tokens
+            + self.context_len
+            >= self.max_len
+        ):
             # to avoid truncate placeholder, apply truncate to the original instruction
-            instruction_truncated = self.tokenizer(last_dialogue.instruction,
-                                                   add_special_tokens=False,
-                                                   truncation=True,
-                                                   max_length=(self.max_len - max_new_tokens - self.context_len -
-                                                               self.dialogue_placeholder_len))['input_ids']
+            instruction_truncated = self.tokenizer(
+                last_dialogue.instruction,
+                add_special_tokens=False,
+                truncation=True,
+                max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
+            )["input_ids"]
             instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
             prompt += _format_dialogue(instruction_truncated)
             return prompt
 
-        res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids'])
+        res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
 
         rows = []
         for dialogue in history[::-1]:
             text = _format_dialogue(dialogue.instruction, dialogue.response)
-            cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
+            cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
             if res_len - cur_len < 0:
                 break
             res_len -= cur_len
             rows.insert(0, text)
-        prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
+        prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
         return prompt
 
     def postprocess_output(self, output: str) -> str:
-        output = STOP_PAT.sub('', output)
+        output = STOP_PAT.sub("", output)
         return output.strip()
 
     def has_censored_words(self, text: str) -> bool:
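preprocess_prompt, reformatted above, keeps the latest instruction (truncating it if it alone would overflow max_len minus the generation budget) and then walks older dialogues newest-first until the remaining token budget is spent. A simplified sketch of that budget walk with a hypothetical whitespace tokenizer, not the repository's tokenizer:

    # Simplified budget walk; toy_tokenize is a hypothetical stand-in that mimics
    # tokenizer(text)["input_ids"] only for counting purposes.
    def toy_tokenize(text: str) -> dict:
        return {"input_ids": text.split()}

    max_len, max_new_tokens = 64, 16
    prompt = "context " * 10
    res_len = max_len - max_new_tokens - len(toy_tokenize(prompt)["input_ids"])

    kept = []
    for turn in ["newest turn " * 8, "older turn " * 8, "oldest turn " * 12]:  # newest first
        cur_len = len(toy_tokenize(turn)["input_ids"])
        if res_len - cur_len < 0:
            break
        res_len -= cur_len
        kept.insert(0, turn)  # oldest kept dialogue ends up first, as in the real code
    print(len(kept), "past dialogue(s) kept, remaining budget:", res_len)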
@@ -184,7 +192,6 @@ class ChatPromptProcessor:
 
 
 class LockedIterator:
-
     def __init__(self, it, lock: Lock) -> None:
         self.lock = lock
         self.it = iter(it)