mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 21:17:08 +00:00
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn * opt tmp tensor * fix benchmark_llama * fix code style * fix None logic for output tensor * fix adapted to get_xine_cache * add comment * fix ci bugs * fix some codes * rm duplicated codes * rm duplicated codes * fix code style * add _get_dtype in config.py
This commit is contained in:
@@ -91,7 +91,7 @@ def benchmark_inference(args):
|
||||
config.pad_token_id = config.eos_token_id
|
||||
model = transformers.LlamaForCausalLM(config).cuda()
|
||||
model = model.eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
|
||||
if args.dtype == "fp16":
|
||||
model = model.half()
|
||||
|
Reference in New Issue
Block a user