mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[application] add lora sft example (#6192)
* [application] add lora sft example * update requirements * update readme * update comment * update ci
This commit is contained in:
@@ -0,0 +1,455 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Supervised fine-tuning of MoE models like Deepseek V3/R1 on a downstream task.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset.loader import RawConversationDataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import (
|
||||
GeminiPlugin,
|
||||
HybridParallelPlugin,
|
||||
LowLevelZeroPlugin,
|
||||
MoeHybridParallelPlugin,
|
||||
Plugin,
|
||||
TorchDDPPlugin,
|
||||
)
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
|
||||
loss = loss.data
|
||||
group = getattr(plugin, "dp_group", None)
|
||||
dist.all_reduce(loss, group=group)
|
||||
return loss / dist.get_world_size(group)
|
||||
|
||||
|
||||
def train(args) -> None:
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
accelerator = get_accelerator()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "ddp":
|
||||
plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||
enable_fused_normalization=get_accelerator().is_available(),
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||
enable_fused_normalization=get_accelerator().is_available(),
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_fused_normalization=get_accelerator().is_available(),
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.microbatch_size,
|
||||
)
|
||||
elif args.plugin == "moe":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
ep_size=args.ep,
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
zero_stage=args.zero_stage,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
enable_sequence_parallelism=args.sp > 1,
|
||||
enable_fused_normalization=get_accelerator().is_available(),
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.microbatch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
def is_master():
|
||||
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
|
||||
return coordinator.rank == coordinator.world_size - 1
|
||||
return coordinator.is_master()
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard and Save Config
|
||||
# ==============================
|
||||
if is_master():
|
||||
if args.tensorboard_dir is not None:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||
# ======================================================
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
|
||||
)
|
||||
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
dataset = RawConversationDataset(
|
||||
tokenizer,
|
||||
args.dataset,
|
||||
args.max_length,
|
||||
)
|
||||
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device())
|
||||
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
||||
else nullcontext()
|
||||
)
|
||||
attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
|
||||
|
||||
config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True)
|
||||
|
||||
with init_ctx:
|
||||
# from_pretrained is not compatible with LoRA, we load pretrained weights later.
|
||||
# model = AutoModelForCausalLM.from_pretrained(
|
||||
# args.pretrained,
|
||||
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
# trust_remote_code=True,
|
||||
# attn_implementation=attn_impl,
|
||||
# )
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config,
|
||||
trust_remote_code=True,
|
||||
attn_implementation=attn_impl,
|
||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
if model.__class__.__name__.startswith("DeepseekV3"):
|
||||
lora_config = LoraConfig(
|
||||
task_type="CAUSAL_LM",
|
||||
r=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=["gate_proj", "up_proj", "down_proj"],
|
||||
)
|
||||
else:
|
||||
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
|
||||
model = booster.enable_lora(model, lora_config=lora_config)
|
||||
|
||||
# this is essential, otherwise the grad checkpoint will not work.
|
||||
model.train()
|
||||
|
||||
if args.use_grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
if model.config.__class__.__name__.startswith("DeepseekV3"):
|
||||
model.config.use_cache = False
|
||||
model.eval()
|
||||
# enable grad for moe layers
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "DeepseekV3MoE":
|
||||
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
|
||||
|
||||
model_numel = sum(p.numel() for p in model.parameters())
|
||||
coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B")
|
||||
|
||||
optimizer = HybridAdam(
|
||||
model_params=model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# Flash attention will be disabled because it does NOT support fp32.
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=dataloader,
|
||||
)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
booster.load_model(model, args.pretrained)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
|
||||
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
|
||||
data_iter = iter(dataloader)
|
||||
step_bar = tqdm(
|
||||
range(len(dataloader)),
|
||||
desc="Step",
|
||||
disable=not is_master(),
|
||||
)
|
||||
for step in step_bar:
|
||||
outputs = booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
criterion=lambda outputs, inputs: outputs[0],
|
||||
optimizer=optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if booster.plugin.stage_manager.is_last_stage():
|
||||
global_loss = all_reduce_mean(loss, plugin)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if booster.plugin.stage_manager.is_last_stage():
|
||||
grad_norm = optimizer.get_grad_norm()
|
||||
step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})
|
||||
|
||||
if args.tensorboard_dir is not None and is_master():
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
|
||||
writer.add_scalar(
|
||||
tag="Learning Rate",
|
||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||
global_step=global_step,
|
||||
)
|
||||
writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)
|
||||
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
else:
|
||||
pbar = tqdm(
|
||||
dataloader,
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not is_master(),
|
||||
initial=start_step // args.accumulation_steps,
|
||||
)
|
||||
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||
for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
|
||||
loss = batch_output.loss / args.accumulation_steps
|
||||
total_loss.add_(loss.data)
|
||||
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
all_reduce_mean(total_loss, plugin)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
grad_norm = optimizer.get_grad_norm()
|
||||
pbar.set_postfix({"loss": total_loss.item(), "grad_norm": grad_norm})
|
||||
if args.tensorboard_dir is not None and is_master():
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||
writer.add_scalar(
|
||||
tag="Learning Rate",
|
||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||
global_step=global_step,
|
||||
)
|
||||
writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)
|
||||
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
total_loss.fill_(0.0)
|
||||
|
||||
# Delete cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
accelerator.empty_cache()
|
||||
|
||||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
if args.lora_rank > 0:
|
||||
booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora"))
|
||||
else:
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Basic training information.
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--pretrained",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Address of the pre-trained model",
|
||||
)
|
||||
parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw Jonl dataset for training.")
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="zero2",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||
parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
|
||||
# Training parameters
|
||||
parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp16", "bf16"],
|
||||
help="Mixed precision",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--use_grad_checkpoint",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use gradient checkpointing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
|
||||
# Additional arguments for 3d plugin.
|
||||
parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
|
||||
parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
|
||||
parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
|
||||
parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
|
||||
parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
|
||||
parser.add_argument(
|
||||
"--sp_mode",
|
||||
type=str,
|
||||
default="split_gather",
|
||||
choices=["split_gather", "ring", "all_to_all"],
|
||||
help="SP mode, used for 3d plugin.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_sequence_parallelism",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to enable SP, used for 3d plugin.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
|
||||
)
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
|
||||
parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1:
|
||||
raise ValueError("Accumulation steps should be 1 when using PP. Please adjust batch size directly.")
|
||||
|
||||
train(args)
|
Reference in New Issue
Block a user