Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[Coati] Train DPO using PP (#6054)
* update dpo
* remove unsupported plugin
* update msg
* update dpo
* remove unsupported plugin
* update msg
* update template
* update dataset
* add pp for dpo
* update dpo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add dpo fn
* update dpo
* update dpo
* update dpo
* update dpo
* minor update
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update loss
* update help
* polish code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
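For context on what this commit trains: DPO contrasts the policy's log-probabilities on chosen vs. rejected responses against those of a frozen reference model. Below is a minimal sketch of the standard DPO loss; the function name and tensor layout are illustrative, not Coati's actual API.

    import torch.nn.functional as F

    def dpo_loss(policy_chosen_logps, policy_rejected_logps,
                 ref_chosen_logps, ref_rejected_logps, beta=0.1):
        # Per-sequence log-ratios of policy vs. the frozen reference model.
        chosen_ratio = beta * (policy_chosen_logps - ref_chosen_logps)
        rejected_ratio = beta * (policy_rejected_logps - ref_rejected_logps)
        # -log(sigmoid(margin)) pushes the policy to prefer chosen responses.
        return -F.logsigmoid(chosen_ratio - rejected_ratio).mean()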
@@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -29,8 +29,6 @@ def train(args):
     # check lora compatibility
     if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
-    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
-        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
 
     # ==============================
     # Initialize Distributed Training
@@ -46,7 +44,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
@@ -56,14 +54,6 @@ def train(args):
             enable_gradient_accumulation=True,
             enable_flash_attention=args.use_flash_attn,
         )
-    elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(
-            precision=args.mixed_precision,
-            placement_policy="auto",
-            initial_scale=2**16,
-            max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn,
-        )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
             stage=2,
@@ -92,20 +82,24 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
 
     booster = Booster(plugin=plugin)
-    ref_booster = Booster(plugin=plugin)
 
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
     # Temp Fix: Disable lazy init due to version conflict
     # init_ctx = (
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
+    ref_plugin = HybridParallelPlugin(
+        tp_size=args.ref_tp,
+        pp_size=1,
+        zero_stage=args.zero_stage,
+        enable_flash_attention=args.use_flash_attn,
+        cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+        parallel_output=False,
+        max_norm=args.grad_clip,
+        precision=args.mixed_precision,
+    )
+    ref_booster = Booster(plugin=ref_plugin)
 
     init_ctx = nullcontext()
     with init_ctx:
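The structural point of this hunk: the reference model no longer shares the policy's plugin (which may now be pipeline-parallel) and instead gets its own HybridParallelPlugin with pp_size=1, since a frozen model that only runs forward passes has no use for a pipeline schedule. A rough sketch of how such a reference model is typically consumed, with hypothetical helper names (padding masking omitted for brevity):

    import torch

    @torch.no_grad()
    def reference_logprobs(ref_model, input_ids, attention_mask):
        # The reference model needs no optimizer, scheduler, or gradients,
        # which is why it is boosted without them later in this diff.
        ref_model.eval()
        logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Shift so that logits at position t score the token at position t+1.
        logps = torch.log_softmax(logits[:, :-1, :], dim=-1)
        return logps.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)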
@@ -130,6 +124,7 @@ def train(args):
             ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
         else:
             ref_model = None
+
         if args.lora_config is not None:
             model = convert_to_lora_module(model, lora_config=lora_config)
             for name, module in model.named_modules():
@@ -139,7 +134,9 @@ def train(args):
         disable_dropout(ref_model)
 
     if args.grad_checkpoint:
-        # Note, for some models, lora may not be compatible with gradient checkpointing
+        # Make sure gradient checkpointing can be activated.
+        model.train()
+        # Note, for some models, lora may not be compatible with gradient checkpointing.
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
 
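The added model.train() is not cosmetic: Hugging Face models gate activation checkpointing on self.training, so recomputation only happens while the module is in training mode, hence the explicit switch before enabling it. Condensed, the pattern is:

    def enable_checkpointing(model):
        # HF transformers guard checkpointing with
        # `if self.gradient_checkpointing and self.training:`,
        # so make sure the module is in training mode.
        model.train()
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )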
@@ -169,7 +166,7 @@ def train(args):
         adamw_mode=True,
     )
 
-    # configure dataset
+    # Configure dataset
     coordinator.print_on_master(f"Load dataset: {args.dataset}")
     mode_map = {"train": "train", "valid": "validation", "test": "test"}
     train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@@ -213,14 +210,15 @@ def train(args):
     default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
     torch.set_default_dtype(default_dtype)
+
     model, optim, _, train_dataloader, lr_scheduler = booster.boost(
         model=model,
         optimizer=optim,
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )
     if ref_model is not None:
-        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
+        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
 
     torch.set_default_dtype(torch.float)
 
     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
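The set_default_dtype bracket around boost() ensures that any parameters or buffers materialized during boosting are created in the chosen half precision, after which float32 is restored for subsequent scalar work. In miniature:

    import torch

    torch.set_default_dtype(torch.bfloat16)
    w = torch.empty(2, 2)   # allocated as bfloat16
    torch.set_default_dtype(torch.float)
    b = torch.empty(2)      # back to float32
    assert w.dtype == torch.bfloat16 and b.dtype == torch.float32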
@@ -312,7 +310,7 @@ if __name__ == "__main__":
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+        choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@@ -342,22 +340,35 @@ if __name__ == "__main__":
     parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
     parser.add_argument("--max_epochs", type=int, default=3)
     parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument(
+        "--microbatch_size",
+        type=int,
+        default=2,
+        help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
+    )
+    # Parameter for reference model
+    parser.add_argument(
+        "--disable_reference_model",
+        action="store_true",
+        default=False,
+        help="Disable the reference model (enabled by default)",
+    )
     parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
     parser.add_argument("--lr", type=float, default=5e-6)
-    parser.add_argument("--accumulation_steps", type=int, default=1)
+    parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default=None, type=str)
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument(
+        "--ref_tp",
+        type=int,
+        default=1,
+        help="TP size for reference model; used only when reference model is too large.",
+    )
     args = parser.parse_args()
 
     # fool proof hyperparameter setup
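A plausible reading of why the --microbatch_size help text insists on an even size of at least 2: each DPO preference example expands into two sequences (chosen and rejected), and a pipeline microbatch must keep each pair together so the paired log-probabilities are available when the loss is computed. A hypothetical illustration of that batching constraint, not the actual Coati dataloader logic:

    def split_into_microbatches(chosen, rejected, microbatch_size):
        # Interleave so each (chosen, rejected) pair stays adjacent,
        # then cut on pair boundaries; hence size must be even and >= 2.
        assert microbatch_size >= 2 and microbatch_size % 2 == 0
        interleaved = [seq for pair in zip(chosen, rejected) for seq in pair]
        for i in range(0, len(interleaved), microbatch_size):
            yield interleaved[i : i + microbatch_size]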
@@ -68,7 +68,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
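Both files converge on the same idiom: find_unused_parameters=not args.grad_checkpoint (the old ternary True if args.grad_checkpoint is False else False said the same thing verbosely). Assuming TorchDDPPlugin forwards this flag to torch's DistributedDataParallel, the interaction it encodes is that DDP's unused-parameter scan can conflict with gradient checkpointing (for example, "marked ready twice" errors under reentrant checkpointing), so the scan is enabled only when checkpointing is off. A bare-bones sketch:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_model(model: nn.Module, grad_checkpoint: bool) -> DDP:
        # With activation checkpointing enabled, DDP's unused-parameter
        # scan can mis-flag recomputed weights, so disable the scan.
        return DDP(model, find_unused_parameters=not grad_checkpoint)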