diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index ee0db250c..08e7550df 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -1,10 +1,8 @@ import argparse -import inspect import json import math import os import resource -import sys from contextlib import nullcontext import torch @@ -26,8 +24,6 @@ logger = get_dist_logger() def train(args): - print(colossalai.__version__, inspect.getfile(colossalai)) - print(sys.executable) # check lora compatibility if "gemini" in args.plugin and args.lora_rank > 0: raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") @@ -44,38 +40,19 @@ def train(args): # ============================== init_ctx = nullcontext() with init_ctx: - model = AutoModelForCausalLM.from_pretrained( - args.pretrain, - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - trust_remote_code=True, - ) - # check if the hybrid parallel plugin is compatible with the model - try: - from colossalai.shardformer.policies.auto_policy import get_autopolicy - - policy = get_autopolicy(model) - if policy is not None: - if args.plugin in ["zero2", "zero2_cpu"]: - # if compatible, set the plugin to hybrid, which use colo-attention - args.plugin = "3d" - args.zero_stage = 2 - if args.plugin == "zero2_cpu": - args.zero_cpu_offload = True - else: - args.zero_cpu_offload = False - logger.info( - f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}" - ) - except NotImplementedError: - logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead") if args.use_flash_attn: - del model model = AutoModelForCausalLM.from_pretrained( args.pretrain, torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, attn_implementation="flash_attention_2", trust_remote_code=True, ) + else: + model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True, + ) if args.lora_rank > 0: model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)