[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
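
The bullets about pre-commit and clang-format refer to the repository's hook configuration rather than to this file's diff. For orientation only, below is a minimal sketch of how a .pre-commit-config.yaml can skip CUDA sources for clang-format while letting black reformat Python; the repo URLs, revisions, and the exclude regex are illustrative assumptions, not the exact configuration changed by this commit:

repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6  # assumed pin; the actual revision may differ
    hooks:
      - id: clang-format
        exclude: '.*\.cu$'  # assumption: skip CUDA kernels ("ignore cuda for clang-format")
  - repo: https://github.com/psf/black
    rev: 23.9.1  # assumed pin; black produces the quote and line-wrapping changes shown in the diff below
    hooks:
      - id: black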


@@ -16,7 +16,7 @@ from sse_starlette.sse import EventSourceResponse
 from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
 from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
 MAX_LEN = 512
 running_lock = Lock()
@@ -36,11 +36,11 @@ app.state.limiter = limiter
 app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
 # set CORS
-origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
+origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
 if origin_spec_from_env is not None:
     # allow CORS from the specified origins
-    origins = os.environ['CORS_ORIGIN'].split(',')
+    origins = os.environ["CORS_ORIGIN"].split(",")
 else:
     # allow CORS from all origins
     origins = ["*"]
@@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
     # TODO(ver217): streaming generation does not support repetition_penalty now
     model_kwargs = {
-        'max_generate_tokens': max_new_tokens,
-        'early_stopping': True,
-        'top_k': top_k,
-        'top_p': top_p,
-        'temperature': temperature,
-        'prepare_inputs_fn': model.prepare_inputs_for_generation,
-        'update_model_kwargs_fn': update_model_kwargs_fn,
+        "max_generate_tokens": max_new_tokens,
+        "early_stopping": True,
+        "top_k": top_k,
+        "top_p": top_p,
+        "temperature": temperature,
+        "prepare_inputs_fn": model.prepare_inputs_for_generation,
+        "update_model_kwargs_fn": update_model_kwargs_fn,
     }
     is_first_word = True
     generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
             if is_first_word:
                 out_string = out_string.lstrip()
                 is_first_word = False
-            elif current_sub_tokens[0].startswith('▁'):
+            elif current_sub_tokens[0].startswith("▁"):
                 # whitespace will be ignored by the frontend
-                out_string = ' ' + out_string
+                out_string = " " + out_string
             yield out_string
@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator):
         if await request.is_disconnected():
             break
         try:
-            yield {'event': 'generate', 'data': next(generator)}
+            yield {"event": "generate", "data": next(generator)}
         except StopIteration:
-            yield {'event': 'end', 'data': ''}
+            yield {"event": "end", "data": ""}
             break
-@app.post('/generate/stream')
-@limiter.limit('1/second')
+@app.post("/generate/stream")
+@limiter.limit("1/second")
 def generate(data: GenerationTaskReq, request: Request):
     prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
     event_source = event_generator(
-        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
+        request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
+    )
     return EventSourceResponse(event_source)
-@app.post('/generate')
-@limiter.limit('1/second')
+@app.post("/generate")
+@limiter.limit("1/second")
 def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
     if prompt_processor.has_censored_words(prompt):
         return prompt_processor.SAFE_RESPONSE
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
     with running_lock:
-        output = model.generate(**inputs, **data.dict(exclude={'history'}))
+        output = model.generate(**inputs, **data.dict(exclude={"history"}))
     output = output.cpu()
-    prompt_len = inputs['input_ids'].size(1)
+    prompt_len = inputs["input_ids"].size(1)
     response = output[0, prompt_len:]
     out_string = tokenizer.decode(response, skip_special_tokens=True)
     out_string = prompt_processor.postprocess_output(out_string)
@@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
     return out_string
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        'pretrained',
-        help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
-    parser.add_argument('--quant',
-                        choices=['8bit', '4bit'],
-                        default=None,
-                        help='Quantization mode. Default: None (no quantization, fp16).')
+        "pretrained",
+        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+    )
     parser.add_argument(
-        '--gptq_checkpoint',
+        "--quant",
+        choices=["8bit", "4bit"],
         default=None,
-        help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
-    parser.add_argument('--gptq_group_size',
-                        type=int,
-                        default=128,
-                        help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
-    parser.add_argument('--http_host', default='0.0.0.0')
-    parser.add_argument('--http_port', type=int, default=7070)
-    parser.add_argument('--profanity_file',
-                        default=None,
-                        help='Path to profanity words list. It should be a JSON file containing a list of words.')
+        help="Quantization mode. Default: None (no quantization, fp16).",
+    )
+    parser.add_argument(
+        "--gptq_checkpoint",
+        default=None,
+        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+    )
+    parser.add_argument(
+        "--gptq_group_size",
+        type=int,
+        default=128,
+        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+    )
+    parser.add_argument("--http_host", default="0.0.0.0")
+    parser.add_argument("--http_port", type=int, default=7070)
+    parser.add_argument(
+        "--profanity_file",
+        default=None,
+        help="Path to profanity words list. It should be a JSON file containing a list of words.",
+    )
     args = parser.parse_args()
-    if args.quant == '4bit':
-        assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+    if args.quant == "4bit":
+        assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
@@ -161,7 +170,7 @@ if __name__ == '__main__':
         censored_words = []
     prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
-    if args.quant == '4bit':
+    if args.quant == "4bit":
         with low_resource_init():
             config = LlamaConfig.from_pretrained(args.pretrained)
             model = LlamaForCausalLM(config)
@@ -170,12 +179,12 @@ if __name__ == '__main__':
     else:
         model = LlamaForCausalLM.from_pretrained(
             args.pretrained,
-            load_in_8bit=(args.quant == '8bit'),
+            load_in_8bit=(args.quant == "8bit"),
             torch_dtype=torch.float16,
             device_map="auto",
         )
-        if args.quant != '8bit':
-            model.half() # seems to fix bugs for some users.
+        if args.quant != "8bit":
+            model.half()  # seems to fix bugs for some users.
     model.eval()
     config = uvicorn.Config(app, host=args.http_host, port=args.http_port)