[Coati] Train DPO using PP (#6054)

* update dpo

* remove unsupported plugin

* update msg

* update dpo

* remove unsupported plugin

* update msg

* update template

* update dataset

* add pp for dpo

* update dpo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dpo fn

* update dpo

* update dpo

* update dpo

* update dpo

* minor update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update loss

* update help

* polish code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Tong Li
Date: 2024-10-11 19:32:00 +08:00
Committed by: GitHub
Parent: dc2cdaf3e8
Commit: 4c8e85ee0d

8 changed files with 529 additions and 236 deletions

(diff for changed file 1 of 8)

@@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -29,8 +29,6 @@ def train(args):
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
-if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
-    raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
@@ -46,7 +44,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for debugging purposes
"""
-plugin = TorchDDPPlugin(find_unused_parameters=True)
+plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
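
Why `find_unused_parameters` is now tied to `--grad_checkpoint`: DDP's unused-parameter scan adds per-step overhead, and when activation checkpointing re-runs a module's forward during backward it can misclassify live parameters as unused, so the flag is only enabled when checkpointing is off. A minimal sketch of the same rule in plain PyTorch (the `build_ddp_model` helper is illustrative, not part of Coati, and assumes a process group is already initialized):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_ddp_model(model: nn.Module, grad_checkpoint: bool) -> DDP:
    # Same rule as the diff: only pay for unused-parameter detection when
    # activation checkpointing is off; with checkpointing on, the recomputed
    # forward can confuse DDP's unused-parameter tracking.
    return DDP(model, find_unused_parameters=not grad_checkpoint)
```
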
@@ -56,14 +54,6 @@ def train(args):
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
-elif args.plugin == "gemini_auto":
-    plugin = GeminiPlugin(
-        precision=args.mixed_precision,
-        placement_policy="auto",
-        initial_scale=2**16,
-        max_norm=args.grad_clip,
-        enable_flash_attention=args.use_flash_attn,
-    )
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
@@ -92,20 +82,24 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
+microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
-ref_booster = Booster(plugin=plugin)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
-# Temp Fix: Disable lazy init due to version conflict
-# init_ctx = (
-#     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-# )
+ref_plugin = HybridParallelPlugin(
+    tp_size=args.ref_tp,
+    pp_size=1,
+    zero_stage=args.zero_stage,
+    enable_flash_attention=args.use_flash_attn,
+    cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+    parallel_output=False,
+    max_norm=args.grad_clip,
+    precision=args.mixed_precision,
+)
+ref_booster = Booster(plugin=ref_plugin)
init_ctx = nullcontext()
with init_ctx:
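
The reference policy in DPO is frozen and only ever runs forward passes, which is why it now gets its own booster: a `HybridParallelPlugin` with `pp_size=1` (there is no pipeline schedule to drive for inference-only use) and tensor parallelism controlled separately via `--ref_tp`. A sketch of how such a frozen reference model is typically consumed (function name and shapes are illustrative, not the Coati API):

```python
import torch

@torch.no_grad()
def reference_logps(ref_model, input_ids, attention_mask):
    # The reference model never trains: eval mode, no gradients,
    # just per-token log-probabilities for the DPO loss.
    ref_model.eval()
    logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
    logps = torch.log_softmax(logits[:, :-1], dim=-1)
    # Log-probability assigned to each realized next token.
    return torch.gather(logps, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
```
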
@@ -130,6 +124,7 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
else:
ref_model = None
+
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
@@ -139,7 +134,9 @@ def train(args):
disable_dropout(ref_model)
if args.grad_checkpoint:
-# Note, for some models, lora may not be compatible with gradient checkpointing
+# Make sure gradient checkpointing can be activated.
+model.train()
+# Note, for some models, lora may not be compatible with gradient checkpointing.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
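
Two details in this hunk are easy to miss: Hugging Face only applies gradient checkpointing while the module is in training mode, hence the new `model.train()` before enabling it, and `use_reentrant=False` selects the non-reentrant autograd path, which cooperates better with DDP and with frozen (e.g. LoRA) parameters. A minimal standalone sketch (assumes a recent `transformers` release that accepts `gradient_checkpointing_kwargs`; the `gpt2` checkpoint is just a placeholder):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
model.train()  # checkpointing only takes effect in training mode
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```
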
@@ -169,7 +166,7 @@ def train(args):
adamw_mode=True,
)
-# configure dataset
+# Configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@@ -213,14 +210,15 @@ def train(args):
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
if ref_model is not None:
-ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
+ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -312,7 +310,7 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@@ -342,22 +340,35 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=1)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
+parser.add_argument(
+    "--microbatch_size",
+    type=int,
+    default=2,
+    help="Micro batch size for PP training. To activate PP training for a DPO-like algorithm, the size must be even and at least 2.",
+)
+# Parameter for reference model
+parser.add_argument(
+    "--disable_reference_model",
+    action="store_true",
+    default=False,
+    help="Disable the reference model (enabled by default)",
+)
+parser.add_argument("--disable_loss_mask", default=False, action="store_true")
+parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
+parser.add_argument("--save_interval", type=int, default=1000, help="number of steps between two checkpoints")
+parser.add_argument("--lr", type=float, default=5e-6)
+parser.add_argument("--accumulation_steps", type=int, default=8)
+parser.add_argument("--log_dir", default=None, type=str)
+parser.add_argument("--use_wandb", default=False, action="store_true")
+parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+parser.add_argument("--use_flash_attn", default=False, action="store_true")
+parser.add_argument(
+    "--ref_tp",
+    type=int,
+    default=1,
+    help="TP size for reference model; used only when the reference model is too large.",
+)
args = parser.parse_args()
# fool proof hyperparameter setup
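
The even-and-at-least-2 constraint on `--microbatch_size` follows from how preference data is laid out: each prompt contributes a (chosen, rejected) pair, and both halves of a pair must ride in the same pipeline micro-batch for the pairwise loss to be computable per stage. A hedged sketch of that constraint and of the standard DPO objective it protects (names are illustrative, not the Coati API):

```python
import torch
import torch.nn.functional as F

def check_microbatch_size(microbatch_size: int) -> None:
    # Chosen/rejected sequences are interleaved, so each pipeline
    # micro-batch must contain whole pairs: even and >= 2.
    if microbatch_size < 2 or microbatch_size % 2 != 0:
        raise ValueError("microbatch_size must be an even number >= 2 for DPO")

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Standard DPO objective: -log sigmoid(beta * (policy margin - reference margin)).
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()
```
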

(diff for changed file 2 of 8)

@@ -68,7 +68,7 @@ def train(args):
Default torch ddp plugin without any acceleration, for debugging purposes
"""
-plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
+plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,