mirror of https://github.com/hpcaitech/ColossalAI.git

Commit: upgrade ppo dpo rm script
@@ -48,29 +48,29 @@ def train(args):
         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
         trust_remote_code=True)
     # check if the hybrid parallel plugin is compatible with the model
-    # try:
-    #     from colossalai.shardformer.policies.auto_policy import get_autopolicy
-    #     policy = get_autopolicy(model)
-    #     if policy is not None:
-    #         if args.plugin in ['zero2', 'zero2_cpu']:
-    #             # if compatible, set the plugin to hybrid, which use colo-attention
-    #             args.plugin = 'hybrid'
-    #             args.zero_stage = 2
-    #             if args.plugin == 'zero2_cpu':
-    #                 args.zero_cpu_offload = True
-    #             else:
-    #                 args.zero_cpu_offload = False
-    #         logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
-    # except NotImplementedError:
-    #     logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
-    # if args.use_flash_attn:
-    #     del model
-    #     model = AutoModelForCausalLM.from_pretrained(
-    #         args.pretrain,
-    #         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-    #         attn_implementation="flash_attention_2",
-    #         trust_remote_code=True
-    #     )
+    try:
+        from colossalai.shardformer.policies.auto_policy import get_autopolicy
+        policy = get_autopolicy(model)
+        if policy is not None:
+            if args.plugin in ['zero2', 'zero2_cpu']:
+                # if compatible, switch to the '3d' plugin, which uses colo-attention; record cpu offload before overwriting args.plugin
+                if args.plugin == 'zero2_cpu':
+                    args.zero_cpu_offload = True
+                else:
+                    args.zero_cpu_offload = False
+                args.plugin = '3d'
+                args.zero_stage = 2
+            logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+    except NotImplementedError:
+        logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
+    if args.use_flash_attn:
+        del model
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True
+        )
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
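Note on the hunk above: a minimal standalone sketch of the fallback it re-enables, assuming colossalai is installed. The pick_plugin wrapper and the SimpleNamespace args are illustrative assumptions, not names from the script.

import logging
from types import SimpleNamespace

def pick_plugin(model, args, logger=logging.getLogger(__name__)):
    try:
        from colossalai.shardformer.policies.auto_policy import get_autopolicy
        if get_autopolicy(model) is not None and args.plugin in ("zero2", "zero2_cpu"):
            args.zero_cpu_offload = args.plugin == "zero2_cpu"  # capture before overwrite
            args.plugin = "3d"  # hybrid parallel plugin with colo-attention
            args.zero_stage = 2
    except (NotImplementedError, ImportError):  # no policy for this model, or colossalai absent
        logger.warning("No shardformer policy found, keeping plugin %s", args.plugin)
    return args

args = pick_plugin(model=None, args=SimpleNamespace(plugin="zero2_cpu", zero_cpu_offload=False))
print(args.plugin)  # "3d" if a policy was found, otherwise the original "zero2_cpu"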
@@ -112,7 +112,7 @@ def train(args):
             cpu_offload=True,
             max_norm=args.grad_clip,
         )
-    elif args.plugin == "hybrid":
+    elif args.plugin == "3d":
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
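For context, a hedged sketch of how this renamed plugin is built and handed to the Booster: same HybridParallelPlugin class, new CLI name "3d". The numeric values stand in for args.tp / args.pp / args.zero_stage / args.grad_clip, and distributed init under torchrun is assumed.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# assumes colossalai.launch_from_torch(...) has already run under torchrun
plugin = HybridParallelPlugin(
    tp_size=2,     # tensor parallel degree (args.tp in the script)
    pp_size=1,     # pipeline parallel degree (args.pp in the script)
    zero_stage=2,  # matches the zero_stage set by the fallback above
    max_norm=1.0,  # gradient clipping (args.grad_clip in the script)
)
booster = Booster(plugin=plugin)
# model, optimizer, lr_scheduler and dataloader are then wrapped together:
# model, optim, _, dataloader, lr_scheduler = booster.boost(
#     model, optim, dataloader=train_dataloader, lr_scheduler=lr_scheduler
# )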
@@ -224,7 +224,6 @@ def train(args):
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )
-    # model = model.to(get_current_device())
     torch.set_default_dtype(torch.float)
 
     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
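The memory report in this hunk reads the process-wide peak. A small sketch of that pattern follows; the reset call is an addition for per-phase measurement, not something the script does.

import torch

def report_peak_mb(tag: str) -> None:
    peak = torch.cuda.max_memory_allocated() / 1024**2  # bytes -> MB
    print(f"{tag} max CUDA memory: {peak:.2f} MB")
    torch.cuda.reset_peak_memory_stats()  # start a fresh window for the next phase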
@@ -309,7 +308,7 @@ if __name__ == "__main__":
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
+        choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
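A self-contained illustration of the updated flag; the standalone parser mirrors the choices in this hunk, and the parse_args input simulates a command line.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--plugin",
    type=str,
    default="gemini",
    choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
    help="Choose which plugin to use",
)
args = parser.parse_args(["--plugin", "3d"])  # "--plugin hybrid" would now be rejected
print(args.plugin)  # 3d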