mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[gemini] fixes for benchmarking (#5847)
* [gemini] fix missing return * [gemini] fix missing arg pass * [gemini] use gather tensor instead of list * [test] enable flash attention for benchmark by default * [test] enable flash attention for benchmark by default --------- Co-authored-by: genghaozhe <939857490@qq.com>
This commit is contained in:
@@ -253,8 +253,13 @@ def main():
|
||||
init_kwargs["empty_init"] = False
|
||||
|
||||
with init_ctx:
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
|
||||
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config,
|
||||
trust_remote_code=True,
|
||||
**init_kwargs,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if config.model_type == "chatglm":
|
||||
@@ -286,7 +291,7 @@ def main():
|
||||
|
||||
with get_profile_context(
|
||||
args.profile,
|
||||
1,
|
||||
args.ignore_steps,
|
||||
len(dataloader) - 1,
|
||||
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||
) as prof:
|
||||
|
Reference in New Issue
Block a user