Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
upgrade colossal-chat: support tp_group > 1, add sequence parallelism (SP) for SFT
@@ -14,13 +14,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
+import inspect
+import sys
+import torch.distributed as dist
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()


 def train(args):
+    print(colossalai.__version__, inspect.getfile(colossalai))
+    print(sys.executable)
     # check lora compatibility
     if "gemini" in args.plugin and args.lora_rank > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin")
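Note on the new logging import: `get_dist_logger` returns ColossalAI's distributed logger, which can restrict a record to chosen ranks and is the usual replacement for the raw debug `print` calls this hunk adds. A minimal hedged sketch (the message text is illustrative):

```python
from colossalai.logging import get_dist_logger

logger = get_dist_logger()
# emit from rank 0 only, instead of printing once per process
logger.info("colossalai initialized", ranks=[0])
```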
@@ -35,6 +42,38 @@ def train(args):
     # ==============================
     # Initialize Booster
     # ==============================
+    init_ctx = nullcontext()
+    with init_ctx:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
+    # check if the hybrid parallel plugin is compatible with the model
+    # try:
+    #     from colossalai.shardformer.policies.auto_policy import get_autopolicy
+    #     policy = get_autopolicy(model)
+    #     if policy is not None:
+    #         if args.plugin in ['zero2', 'zero2_cpu']:
+    #             # if compatible, set the plugin to hybrid, which uses colo-attention
+    #             args.plugin = 'hybrid'
+    #             args.zero_stage = 2
+    #             if args.plugin == 'zero2_cpu':
+    #                 args.zero_cpu_offload = True
+    #             else:
+    #                 args.zero_cpu_offload = False
+    #         logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+    # except NotImplementedError:
+    #     logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
+    # if args.use_flash_attn:
+    #     del model
+    #     model = AutoModelForCausalLM.from_pretrained(
+    #         args.pretrain,
+    #         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+    #         attn_implementation="flash_attention_2",
+    #         trust_remote_code=True,
+    #     )
+    if args.lora_rank > 0:
+        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
     if args.plugin == "ddp":
         """
         Default torch ddp plugin without any acceleration, for
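`convert_to_lora_module` comes from the coati package; the hunk only shows its call site. For readers unfamiliar with the idea, here is a self-contained sketch of what a LoRA wrapper does conceptually: freeze the pretrained weight and train two low-rank factors. The names `LoRALinear`, `lora_a`, and `lora_b` are illustrative, not coati's implementation.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the pretrained weight
        # low-rank factors: A (rank x in), B (out x rank), B starts at zero
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + B A x: only the rank-r factors receive gradients
        return self.base(x) + x @ self.lora_a.T @ self.lora_b.T

layer = LoRALinear(nn.Linear(512, 512), rank=8)
out = layer(torch.randn(2, 512))  # (2, 512)
```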
@@ -47,7 +86,8 @@ def train(args):
             placement_policy="static",
             initial_scale=2**16,
             max_norm=args.grad_clip,
-            enable_gradient_accumulation=True,
+            enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -55,6 +95,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -71,11 +112,16 @@ def train(args):
             cpu_offload=True,
             max_norm=args.grad_clip,
         )
-    elif args.plugin == "3d":
+    elif args.plugin == "hybrid":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
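The new flags wire straight into `HybridParallelPlugin`. Below is a hedged, minimal sketch of how they compose under a `torchrun`/`colossalai run` launch with, say, `--plugin hybrid --tp 2 --sp 2 --sp_mode all_to_all --zero_stage 1`; the values are illustrative, and the exact constraints between `sp_mode`, `tp`, and `sp` vary across ColossalAI versions. `parallel_output=False` gathers the logits across the TP group, presumably so the unmodified loss code sees full-vocabulary outputs.

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # expects torchrun-style env vars; older releases took a config dict
# e.g. 8 GPUs with tp=2, pp=1, sp=2 leaves a data-parallel degree of
# roughly world_size / (tp * pp * sp) = 2 (group layout is version-dependent)
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,
    sequence_parallelism_mode="all_to_all",
    enable_sequence_parallelism=True,
    zero_stage=1,
    enable_flash_attention=True,
    parallel_output=False,
    precision="bf16",
    max_norm=1.0,
)
booster = Booster(plugin=plugin)
```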
@@ -93,20 +139,7 @@ def train(args):
     # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
-    init_ctx = nullcontext()
-    with init_ctx:
-        if args.use_flash_attn:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrain,
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                use_flash_attention_2=True,
-            )
-            coordinator.print_on_master(msg="Flash-attention enabled successfully")
-        else:
-            model = AutoModelForCausalLM.from_pretrained(args.pretrain)
-        if args.lora_rank > 0:
-            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-

     if args.grad_checkpoint and args.lora_rank == 0:
         # lora layers are not supported by gradient checkpointing
         model.gradient_checkpointing_enable()
@@ -131,6 +164,7 @@ def train(args):
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False
+    tokenizer.padding_side = "right"

     coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
     coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
@@ -156,8 +190,13 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )
+    # print(len(train_dataloader))
+    # for batch in train_dataloader:
+    #     print(dist.get_rank(), tokenizer.batch_decode(batch["input_ids"]))
+    #     break

     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
     )
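The dataloader now receives `tp_size` rather than a `use_tp` boolean. Every rank in one tensor-parallel group must consume the same batch, so the sampler should shard the dataset across model replicas, not across raw ranks. A hedged sketch of that sharding rule; the function and variable names here are illustrative, not coati's API.

```python
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_sft_dataloader(dataset, batch_size, collate_fn, tp_size=1):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # all tp ranks of one model replica map to the same sampler shard
    num_replicas = world_size // tp_size
    replica_rank = rank // tp_size
    sampler = DistributedSampler(
        dataset, num_replicas=num_replicas, rank=replica_rank,
        shuffle=True, drop_last=True,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      collate_fn=collate_fn, drop_last=True)
```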
@@ -255,7 +294,7 @@ def train(args):
     # save model checkpoint after fitting on only rank0
     coordinator.print_on_master("Start saving final model checkpoint")

-    booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
+    # booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
     coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")

     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -270,13 +309,18 @@ if __name__ == "__main__":
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
+        choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
     parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])
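A sanity check that pairs naturally with the new `--tp/--pp/--sp` flags: their product must divide the launch's world size, or process-group creation fails. A hedged sketch (the exact relation between `sp` and the resulting data-parallel degree is version-dependent, and `check_parallel_args` is a hypothetical helper):

```python
import torch.distributed as dist

def check_parallel_args(args) -> int:
    world_size = dist.get_world_size()
    model_parallel = args.tp * args.pp * args.sp
    assert world_size % model_parallel == 0, (
        f"world size {world_size} is not divisible by tp*pp*sp = {model_parallel}"
    )
    return world_size // model_parallel  # nominal data-parallel degree
```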
@@ -287,7 +331,7 @@ if __name__ == "__main__":
     parser.add_argument("--max_epochs", type=int, default=3)
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
-    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
     parser.add_argument(
         "--lora_train_bias",