[gemini] fixes for benchmarking (#5847)

* [gemini] fix missing return

* [gemini] fix missing arg pass

* [gemini] use gather tensor instead of list

* [test] enable flash attention for benchmark by default

* [test] enable flash attention for benchmark by default

---------

Co-authored-by: genghaozhe <939857490@qq.com>
This commit is contained in:
botbw
2024-06-26 15:52:09 +08:00
committed by GitHub
parent 2a25a2aff7
commit 8e718a1421
5 changed files with 27 additions and 15 deletions

View File

@@ -253,8 +253,13 @@ def main():
init_kwargs["empty_init"] = False
with init_ctx:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=True,
**init_kwargs,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if config.model_type == "chatglm":
@@ -286,7 +291,7 @@ def main():
with get_profile_context(
args.profile,
1,
args.ignore_steps,
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof: