mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-09 20:14:29 +00:00
Merge branch 'main' into sync/npu
This commit is contained in:
@@ -71,9 +71,10 @@ def main():
|
||||
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
|
||||
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
|
||||
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||
parser.add_argument("--mbs", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=0)
|
||||
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
|
||||
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch({})
|
||||
@@ -92,9 +93,17 @@ def main():
|
||||
shard_param_frac=args.shard_param_frac,
|
||||
offload_optim_frac=args.offload_optim_frac,
|
||||
offload_param_frac=args.offload_param_frac,
|
||||
tp_size=args.tp,
|
||||
extra_dp_size=args.extra_dp,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio)
|
||||
plugin = GeminiPlugin(
|
||||
placement_policy="auto",
|
||||
precision="bf16",
|
||||
warmup_non_model_data_ratio=args.warmup_ratio,
|
||||
tp_size=args.tp,
|
||||
extra_dp_size=args.extra_dp,
|
||||
)
|
||||
elif args.plugin == "fsdp":
|
||||
if use_empty_init:
|
||||
plugin = TorchFSDPPlugin(
|
||||
@@ -129,9 +138,11 @@ def main():
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=args.pp,
|
||||
pp_style="interleaved",
|
||||
zero_stage=args.zero,
|
||||
num_model_chunks=2,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
num_microbatches=args.mbs,
|
||||
microbatch_size=args.mbs,
|
||||
precision="bf16",
|
||||
)
|
||||
elif args.plugin == "3d_cpu":
|
||||
@@ -141,7 +152,7 @@ def main():
|
||||
zero_stage=args.zero,
|
||||
cpu_offload=True,
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
num_microbatches=args.mbs,
|
||||
microbatch_size=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision="bf16",
|
||||
)
|
||||
|
Reference in New Issue
Block a user