#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Supervised fine-tuning of MoE models such as DeepSeek V3/R1 on a downstream task.
"""

import argparse
import json
import os
import resource
from contextlib import nullcontext
from types import MethodType

import torch
import torch.distributed as dist
from coati.dataset.loader import RawConversationDataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import (
    GeminiPlugin,
    HybridParallelPlugin,
    LowLevelZeroPlugin,
    MoeHybridParallelPlugin,
    Plugin,
    TorchDDPPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
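

# Average a scalar loss across the data-parallel group so that every rank logs
# the same value. `dp_group` is None for plugins without an explicit DP group,
# in which case the all-reduce falls back to the default (world) process group.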
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
    loss = loss.data
    group = getattr(plugin, "dp_group", None)
    dist.all_reduce(loss, group=group)
    return loss / dist.get_world_size(group)


def train(args) -> None:
    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    accelerator = get_accelerator()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
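    # Plugin cheat sheet (matching the branches below):
    #   ddp                  - plain PyTorch DDP
    #   gemini / gemini_auto - Gemini heterogeneous memory management
    #   zero2 / zero2_cpu    - ZeRO stage 2, optionally with CPU offload
    #   3d                   - hybrid TP/PP/SP parallelism (+ optional ZeRO)
    #   moe                  - hybrid parallelism with expert parallelism for MoE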
    if args.plugin == "ddp":
        plugin = TorchDDPPlugin(find_unused_parameters=not args.use_grad_checkpoint)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_fused_normalization=get_accelerator().is_available(),
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=args.zero_stage >= 1 and args.zero_cpu_offload,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
            microbatch_size=args.microbatch_size,
        )
    elif args.plugin == "moe":
        plugin = MoeHybridParallelPlugin(
            ep_size=args.ep,
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero_stage,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            enable_sequence_parallelism=args.sp > 1,
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
            microbatch_size=args.microbatch_size,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

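    # Under pipeline parallelism the loss is only produced on the last stage,
    # so the "master" rank for logging is the last rank rather than rank 0.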
    def is_master():
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            return coordinator.rank == coordinator.world_size - 1
        return coordinator.is_master()

    # ==============================
    # Initialize Tensorboard and Save Config
    # ==============================
    if is_master():
        if args.tensorboard_dir is not None:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(args.tensorboard_dir, exist_ok=True)
            writer = SummaryWriter(args.tensorboard_dir)

        with open(args.config_file, "w") as f:
            json.dump(args.__dict__, f, indent=4)

    # ======================================================
    # Initialize Tokenizer, Dataset, Collator and Dataloader
    # ======================================================
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)

    coordinator.print_on_master(
        f"Training Info:\nConfig file: {args.config_file}\nTensorboard logs: {args.tensorboard_dir}\nModel checkpoint: {args.save_dir}"
    )

    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    dataset = RawConversationDataset(
        tokenizer,
        args.dataset,
        args.max_length,
    )

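    # prepare_dataloader attaches a plugin-aware distributed sampler, so each
    # data-parallel rank draws a distinct shard of the dataset per epoch.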
    dataloader = plugin.prepare_dataloader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )

    coordinator.print_on_master(
        f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
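    # LazyInitContext defers parameter materialization until booster.boost, so
    # Gemini/3D plugins can shard the model without first allocating it in full
    # on every rank.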
    # Note: for some models (e.g. ChatGLM), LoRA and gradient checkpointing are incompatible.
    init_ctx = (
        LazyInitContext(default_device=get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )
    attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"

    config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True)

    with init_ctx:
        # from_pretrained is not compatible with LoRA here, so the pretrained
        # weights are loaded later via booster.load_model.
        # model = AutoModelForCausalLM.from_pretrained(
        #     args.pretrained,
        #     torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        #     trust_remote_code=True,
        #     attn_implementation=attn_impl,
        # )
        model = AutoModelForCausalLM.from_config(
            config,
            trust_remote_code=True,
            attn_implementation=attn_impl,
            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        )

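        # For DeepSeek V3, LoRA targets only the MLP projection matrices
        # (gate/up/down, including the expert MLPs); attention projections are
        # left untouched, which keeps the adapter small.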
        if args.lora_rank > 0:
            if model.__class__.__name__.startswith("DeepseekV3"):
                lora_config = LoraConfig(
                    task_type="CAUSAL_LM",
                    r=args.lora_rank,
                    lora_alpha=args.lora_alpha,
                    target_modules=["gate_proj", "up_proj", "down_proj"],
                )
            else:
                lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
            model = booster.enable_lora(model, lora_config=lora_config)

    # This is essential; otherwise gradient checkpointing will not work.
    model.train()

    if args.use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
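    # DeepSeek V3's MoE block dispatches tokens through `moe_infer`, which the
    # upstream modeling code wraps in `torch.no_grad()`. Rebinding the
    # undecorated function (exposed as `__wrapped__`) re-enables gradients for
    # the expert layers, and `model.eval()` keeps the MoE block on that
    # inference code path during training.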
    if model.config.__class__.__name__.startswith("DeepseekV3"):
        model.config.use_cache = False
        model.eval()
        # enable grad for moe layers
        for m in model.modules():
            if m.__class__.__name__ == "DeepseekV3MoE":
                m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)

    model_numel = sum(p.numel() for p in model.parameters())
    coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B")

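    # HybridAdam can apply parameter updates on both CPU and GPU, which pairs
    # well with plugins that offload optimizer states; adamw_mode=True selects
    # decoupled weight decay (AdamW).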
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    if args.warmup_steps is None:
        args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    # Set the default dtype before boosting; flash attention does not support
    # fp32 and would otherwise be disabled.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

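    # The pretrained weights are loaded only after boosting: the model is
    # already sharded and wrapped by the plugin at this point, so each rank
    # only needs to materialize its own part of the checkpoint.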
    torch.set_default_dtype(torch.float)
    booster.load_model(model, args.pretrained)

    coordinator.print_on_master(
        f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    start_step = 0

    num_steps_per_epoch = len(dataloader) // args.accumulation_steps

    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch=epoch)
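        # With pipeline parallelism, the booster drives the whole
        # forward/backward schedule internally, so it is fed a batch iterator
        # rather than individual batches.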
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            data_iter = iter(dataloader)
            step_bar = tqdm(
                range(len(dataloader)),
                desc="Step",
                disable=not is_master(),
            )
            for step in step_bar:
                outputs = booster.execute_pipeline(
                    data_iter,
                    model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=optimizer,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if booster.plugin.stage_manager.is_last_stage():
                    global_loss = all_reduce_mean(loss, plugin)

                optimizer.step()

                if booster.plugin.stage_manager.is_last_stage():
                    grad_norm = optimizer.get_grad_norm()
                    step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})

                if args.tensorboard_dir is not None and is_master():
                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                    writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
                    writer.add_scalar(
                        tag="Learning Rate",
                        scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step,
                    )
                    writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)

                lr_scheduler.step()
                optimizer.zero_grad()

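        # Non-pipeline path: plain training loop with gradient accumulation;
        # the optimizer steps once every `accumulation_steps` micro-batches.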
        else:
            pbar = tqdm(
                dataloader,
                desc=f"Epoch {epoch}",
                disable=not is_master(),
                initial=start_step // args.accumulation_steps,
            )
            total_loss = torch.tensor(0.0, device=get_current_device())
            for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}

                batch_output = model(**batch)

                loss = batch_output.loss / args.accumulation_steps
                total_loss.add_(loss.data)

                booster.backward(loss=loss, optimizer=optimizer)

                if (step + 1) % args.accumulation_steps == 0:
                    global_loss = all_reduce_mean(total_loss, plugin)

                    optimizer.step()

                    grad_norm = optimizer.get_grad_norm()
                    pbar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})
                    if args.tensorboard_dir is not None and is_master():
                        global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                        writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
                        writer.add_scalar(
                            tag="Learning Rate",
                            scalar_value=lr_scheduler.get_last_lr()[0],
                            global_step=global_step,
                        )
                        writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)

                    lr_scheduler.step()
                    optimizer.zero_grad()

                    total_loss.fill_(0.0)

        # Release cached device memory between epochs.
        # del batch, batch_labels, batch_output, loss
        accelerator.empty_cache()

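    # A LoRA run saves only the adapter weights (under args.save_dir/lora);
    # merge them into the base model afterwards if a standalone checkpoint is
    # needed.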
    # Final save.
    coordinator.print_on_master("Start saving final model checkpoint")
    if args.lora_rank > 0:
        booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora"))
    else:
        booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} in folder {args.save_dir}")

    coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Basic training information.
    parser.add_argument(
        "-m",
        "--pretrained",
        type=str,
        required=True,
        help="Path or name of the pretrained model",
    )
    parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw JSONL dataset for training.")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="zero2",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
    parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
    parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
    # Training parameters
    parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size of each data-parallel process")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["fp16", "bf16"],
        help="Mixed precision",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument(
        "-g",
        "--use_grad_checkpoint",
        action="store_true",
        default=False,
        help="Use gradient checkpointing",
    )
    parser.add_argument(
        "-f",
        "--use_flash_attn",
        action="store_true",
        default=False,
        help="Use flash-attention",
    )

    # Additional arguments for 3d plugin.
    parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
    parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
    parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
    parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
    parser.add_argument("--zero_stage", type=int, default=1, help="ZeRO stage, used for 3d plugin.", choices=[0, 1, 2])
    parser.add_argument(
        "--sp_mode",
        type=str,
        default="split_gather",
        choices=["split_gather", "ring", "all_to_all"],
        help="SP mode, used for 3d plugin.",
    )
    parser.add_argument(
        "--enable_sequence_parallelism",
        default=False,
        action="store_true",
        help="Whether to enable SP, used for 3d plugin.",
    )
    parser.add_argument(
        "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
    )
    parser.add_argument(
        "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
    )
    parser.add_argument("--lora_rank", type=int, default=0, help="LoRA rank; 0 disables LoRA.")
    parser.add_argument("--lora_alpha", type=int, default=8, help="LoRA alpha (scaling), used when LoRA is enabled.")

    args = parser.parse_args()

    if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1:
        raise ValueError("Accumulation steps should be 1 when using PP. Please adjust the batch size directly.")

    train(args)
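

# Example launch (illustrative values only; assumes this script is saved as
# lora_finetune.py and that 8 GPUs are available):
#   torchrun --nproc_per_node 8 lora_finetune.py \
#       -m deepseek-ai/DeepSeek-V3 -d data.jsonl -p moe \
#       --ep 8 --zero_stage 1 --batch_size 2 --lora_rank 8 -g -f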