upgrade colossal-chat: support tp_group > 1, add sequence parallelism (sp) for SFT

YeAnbang
2024-05-27 05:55:57 +00:00
parent 73e88a5553
commit 7a7e86987d
33 changed files with 7574 additions and 105 deletions


@@ -14,13 +14,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
import inspect
import sys
import torch.distributed as dist
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def train(args):
print(colossalai.__version__, inspect.getfile(colossalai))
print(sys.executable)
# check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
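The guard above only covers the Gemini/LoRA conflict. Since this commit adds tensor, pipeline, and sequence parallelism flags, a fail-fast check in the same spirit could catch unsupported combinations before any GPU memory is touched. A minimal sketch (hypothetical helper, not part of this commit; the constraint that tp/pp/sp > 1 requires the hybrid plugin is an assumption based on how the flags are consumed below):

    def validate_parallel_args(args):
        # Hypothetical guard mirroring the existing Gemini/LoRA check.
        if "gemini" in args.plugin and args.lora_rank > 0:
            raise ValueError("LoRA is not supported in GeminiPlugin. Please use another plugin")
        if args.plugin != "hybrid" and (args.tp > 1 or args.pp > 1 or args.sp > 1):
            raise ValueError("--tp/--pp/--sp > 1 require --plugin hybrid (assumption)")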
@@ -35,6 +42,38 @@ def train(args):
# ==============================
# Initialize Booster
# ==============================
init_ctx = nullcontext()
with init_ctx:
model = AutoModelForCausalLM.from_pretrained(args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True)
# check if the hybrid parallel plugin is compatible with the model
# try:
# from colossalai.shardformer.policies.auto_policy import get_autopolicy
# policy = get_autopolicy(model)
# if policy is not None:
# if args.plugin in ['zero2', 'zero2_cpu']:
# # if compatible, set the plugin to hybrid, which use colo-attention
# args.plugin = 'hybrid'
# args.zero_stage = 2
# if args.plugin == 'zero2_cpu':
# args.zero_cpu_offload = True
# else:
# args.zero_cpu_offload = False
# logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
# except NotImplementedError:
# logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
# if args.use_flash_attn:
# del model
# model = AutoModelForCausalLM.from_pretrained(
# args.pretrain,
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
# attn_implementation="flash_attention_2",
# trust_remote_code=True
# )
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
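convert_to_lora_module is ColossalChat's own helper. For readers more familiar with the HuggingFace peft library, a rough functional equivalent looks like the following (illustration only; the repo does not use peft, and lora_alpha = 2 * rank is just a common heuristic):

    from peft import LoraConfig, get_peft_model

    def to_lora(model, rank: int, train_bias: str = "none"):
        # Approximates convert_to_lora_module(model, rank, lora_train_bias=...)
        config = LoraConfig(r=rank, lora_alpha=2 * rank, bias=train_bias, task_type="CAUSAL_LM")
        return get_peft_model(model, config)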
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for
@@ -47,7 +86,8 @@ def train(args):
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
-enable_gradient_accumulation=True,
+enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
+enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -55,6 +95,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -71,11 +112,16 @@ def train(args):
cpu_offload=True,
max_norm=args.grad_clip,
)
-elif args.plugin == "3d":
+elif args.plugin == "hybrid":
plugin = HybridParallelPlugin(
tp_size=args.tp,
-pp_size=1,
-zero_stage=0,
+pp_size=args.pp,
+sp_size=args.sp,
+sequence_parallelism_mode=args.sp_mode,
+zero_stage=args.zero_stage,
+enable_flash_attention=args.use_flash_attn,
+enable_sequence_parallelism=True if args.sp > 1 else False,
+cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
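With tp, pp, and sp all configurable, the data-parallel degree is whatever remains of the world size. A back-of-the-envelope check (a sketch only; the exact group layout is decided inside HybridParallelPlugin, and the assumption here is that only the "all_to_all" mode carves its sequence-parallel ranks out of the data-parallel dimension, while "split_gather" and "ring" reuse the tensor-parallel group):

    def infer_dp_size(world_size: int, tp: int, pp: int, sp: int, sp_mode: str) -> int:
        # Assumption: sp consumes extra ranks only in "all_to_all" mode.
        denom = tp * pp * (sp if sp_mode == "all_to_all" else 1)
        if world_size % denom != 0:
            raise ValueError(f"world_size {world_size} not divisible by {denom}")
        return world_size // denom

    infer_dp_size(8, tp=2, pp=1, sp=2, sp_mode="all_to_all")  # -> 2 data-parallel replicas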
@@ -93,20 +139,7 @@ def train(args):
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
# lora layers are not supported by gradient checkpointing
model.gradient_checkpointing_enable()
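The comment above states why checkpointing is skipped under LoRA: checkpointed segments recompute their forward pass, and if the frozen embedding outputs do not require grad, the recomputed graph detaches from the trainable adapters. Transformers ships a hook that works around this; a hedged sketch of how the two could coexist (not what this commit does, which simply disables checkpointing when lora_rank > 0):

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
        if args.lora_rank > 0:
            # Force embedding outputs to require grad so recomputed
            # segments stay connected to the LoRA parameters.
            model.enable_input_require_grads()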
@@ -131,6 +164,7 @@ def train(args):
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "right"
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
@@ -156,8 +190,13 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
-use_tp=args.tp > 1,
+tp_size=args.tp,
)
# print(len(train_dataloader))
# for batch in train_dataloader:
# print(dist.get_rank(), tokenizer.batch_decode(batch["input_ids"]))
# break
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
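Threading tp_size into the dataloader matters because every rank inside a tensor-parallel group must consume an identical batch; only data-parallel replicas should see different shards. A minimal sketch of the rank arithmetic the helper presumably performs internally (assuming TP ranks are laid out contiguously):

    import torch.distributed as dist
    from torch.utils.data import DistributedSampler

    def build_sampler(dataset, tp_size: int) -> DistributedSampler:
        # Collapse each TP group onto one data-parallel rank so that all
        # ranks in the group draw the same samples.
        dp_rank = dist.get_rank() // tp_size
        dp_world_size = dist.get_world_size() // tp_size
        return DistributedSampler(dataset, num_replicas=dp_world_size, rank=dp_rank, shuffle=True, drop_last=True)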
@@ -255,7 +294,7 @@ def train(args):
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final model checkpoint")
-booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
+# booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}")
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
@@ -270,13 +309,18 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"],
choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
@@ -287,7 +331,7 @@ if __name__ == "__main__":
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(
"--lora_train_bias",