mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
@@ -752,7 +752,19 @@ We support the method introduced in the paper [ORPO: Monolithic Preference Optim
</p>

## Hardware Requirements
For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use an H800 GPU with 80GB VRAM.

For SFT, we recommend using zero2 or zero2_cpu for a 7B model, and tp if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
- 2 H800 GPUs
  - zero2_cpu, micro batch size=4, VRAM Usage=22457.98 MB
  - zero2, micro batch size=4, VRAM Usage=72390.95 MB
- 4 H800 GPUs
  - zero2_cpu, micro batch size=8, VRAM Usage=19412.77 MB
  - zero2, micro batch size=8, VRAM Usage=43446.31 MB
  - zero2, micro batch size=16, VRAM Usage=58082.30 MB
  - zero2, micro batch size=8, lora_rank=8, VRAM Usage=21167.73 MB
  - zero2, micro batch size=8, lora_rank=32, VRAM Usage=21344.17 MB

For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model (llama2-7B-hf) on a dummy dataset with a sequence length of 2048 and a layout length of 512 with different tp_size (equal to the number of GPUs).
| PPO   | tp=8          | tp=4          |
|-------|---------------|---------------|
| bs=1  | 18485.19 MB   | 42934.45 MB   |
@@ -763,12 +775,31 @@ For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM

For DPO, we recommend using zero2 or zero2_cpu. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048.


- 1 H800 GPU
  - zero2_cpu, batch size=2, VRAM Usage=49873.90 MB
  - zero2_cpu, batch size=4, VRAM Usage=60998.22 MB
- 2 H800 GPUs
  - zero2_cpu, micro batch size=2, VRAM Usage=36989.37 MB
  - zero2_cpu, micro batch size=4, VRAM Usage=48081.67 MB
- 4 H800 GPUs
  - zero2, batch size=4, VRAM Usage=67544.47 MB
  - zero2, micro batch size=4, VRAM Usage=67483.44 MB

For SimPO, we recommend using zero2 or zero2_cpu. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048.

- 2 H800 GPUs
  - zero2_cpu, micro batch size=4, VRAM Usage=25705.26 MB
  - zero2, micro batch size=4, VRAM Usage=73375.04 MB
- 4 H800 GPUs
  - zero2_cpu, micro batch size=8, VRAM Usage=36709.36 MB
  - zero2, micro batch size=4, VRAM Usage=44330.90 MB
  - zero2, micro batch size=8, VRAM Usage=56086.12 MB

For ORPO, we recommend using zero2 or zero2_cpu. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048.

- 2 H800 GPUs
  - zero2_cpu, micro batch size=4, VRAM Usage=26693.38 MB
  - zero2, micro batch size=4, VRAM Usage=74332.65 MB
- 4 H800 GPUs
  - zero2_cpu, micro batch size=8, VRAM Usage=38709.73 MB
  - zero2, micro batch size=4, VRAM Usage=45309.52 MB
  - zero2, micro batch size=8, VRAM Usage=58086.37 MB

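The measurements listed in this section can be compared by tabulating them and computing relative savings. The sketch below is illustrative only: the helper `vram_saving` and the dictionary labels are hypothetical names chosen for this example, and the figures are copied from the SFT list for 4 H800 GPUs with micro batch size 8.

```python
# Peak VRAM in MB, copied from the SFT measurements in this section
# (4 H800 GPUs, micro batch size 8). The keys are labels chosen for
# this sketch, not option names taken from the training scripts.
SFT_4GPU_MBS8 = {
    "zero2": 43446.31,
    "zero2_cpu": 19412.77,
    "zero2 + lora_rank=8": 21167.73,
    "zero2 + lora_rank=32": 21344.17,
}


def vram_saving(baseline_mb: float, other_mb: float) -> float:
    """Return the fraction of VRAM saved relative to the baseline."""
    return 1.0 - other_mb / baseline_mb


if __name__ == "__main__":
    baseline = SFT_4GPU_MBS8["zero2"]
    for name, usage in SFT_4GPU_MBS8.items():
        saving = vram_saving(baseline, usage)
        print(f"{name:>22}: {usage:9.2f} MB, saving vs zero2 = {saving:6.1%}")
```

On these figures, CPU offload and LoRA each cut VRAM usage by roughly half relative to plain zero2 at the same micro batch size. The list reports memory only, so this comparison says nothing about throughput.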
## List of Supported Models
@@ -128,16 +128,14 @@ def train(args):
        disable_dropout(ref_model)
    else:
        ref_model = None
    print("ref_model is None", args.disable_reference_model, ref_model is None)
    if args.lora_rank > 0:
        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

    if args.grad_checkpoint and args.lora_rank == 0:
    if args.grad_checkpoint:
        # Note, for some models, lora may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    elif args.lora_rank > 0:
        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")


    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
@@ -118,12 +118,11 @@ def train(args):
    if args.lora_rank > 0:
        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

    if args.grad_checkpoint and args.lora_rank == 0:
    if args.grad_checkpoint:
        # Note, for some models, lora may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    elif args.lora_rank > 0:
        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")


    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
@@ -122,13 +122,11 @@ def train(args):
    # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )

    if args.grad_checkpoint and args.lora_rank == 0:
        # lora layers are not supported by gradient checkpointing
    if args.grad_checkpoint:
        # Note, for some models, lora may not be compatible with gradient checkpointing
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    elif args.lora_rank > 0:
        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")


    # configure tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
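As the Python hunks read, the change to the training scripts amounts to dropping the `lora_rank == 0` guard, so that gradient checkpointing depends on the flag alone. A minimal sketch of the decision logic before and after follows; the function names are hypothetical and only mirror the conditions visible in the hunks, they are not part of ColossalAI.

```python
def grad_checkpoint_enabled_before(grad_checkpoint: bool, lora_rank: int) -> bool:
    """Old condition: checkpointing was skipped whenever LoRA was active."""
    return grad_checkpoint and lora_rank == 0


def grad_checkpoint_enabled_after(grad_checkpoint: bool, lora_rank: int) -> bool:
    """New condition: checkpointing follows the flag, with or without LoRA."""
    # lora_rank is accepted only to keep the two signatures comparable.
    return grad_checkpoint


if __name__ == "__main__":
    # Enumerate the four flag/rank combinations and show where behavior differs.
    for grad_checkpoint in (False, True):
        for lora_rank in (0, 8):
            before = grad_checkpoint_enabled_before(grad_checkpoint, lora_rank)
            after = grad_checkpoint_enabled_after(grad_checkpoint, lora_rank)
            print(f"grad_checkpoint={grad_checkpoint}, lora_rank={lora_rank}: "
                  f"before={before}, after={after}")
```

The only combination whose outcome changes is the flag set together with a positive LoRA rank. As the added comment in the hunks warns, some models may still be incompatible with that combination.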