Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
@@ -187,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
 
 
 # Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
     fp16_init_scale = 2**5
     gpu_margin_mem_ratio_for_auto = 0
 
     if version.parse(CAI_VERSION) > version.parse("0.1.10"):
         model = GeminiDDP(model,
+                          strict_ddp_mode=ddp_flag,
                           device=get_current_device(),
                           placement_policy=placement_policy,
                           pin_memory=True,
                           hidden_dim=model.config.n_embd,
-                          search_range_mb=64)
+                          search_range_mb=128)
         # configure the const policy
         if placement_policy == 'const':
             model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
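For readers of this hunk: the new ddp_flag argument is simply forwarded to GeminiDDP's strict_ddp_mode keyword, which the diff shows alongside the other Gemini options. A minimal sketch of the resulting wrapper, assuming the import paths used elsewhere in this example script; the helper name wrap_with_gemini is illustrative and not part of the commit:

from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

def wrap_with_gemini(model, placement_policy: str = "auto", ddp_flag: bool = True):
    # With strict_ddp_mode=True, Gemini treats the whole world as one
    # data-parallel group, so the flag should only be enabled when the
    # model is not already sharded by tensor parallelism.
    return GeminiDDP(model,
                     strict_ddp_mode=ddp_flag,
                     device=get_current_device(),
                     placement_policy=placement_policy,
                     pin_memory=True,
                     hidden_dim=model.config.n_embd,
                     search_range_mb=128)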
@@ -279,11 +280,12 @@ def main():
         tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
         # You should notice that v0.1.10 is not compatible with TP degree > 1
-        tensor_parallelize(model, tp_pg)
+        if args.tp_degree > 1:
+            tensor_parallelize(model, tp_pg)
 
         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, tp_pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
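The change to main() pairs the two knobs: tensor_parallelize is only applied when args.tp_degree > 1, and strict DDP is only requested when the degree is 1, which is exactly the condition the new ddp_flag encodes. A hedged restatement of that call pattern, where args is the script's parsed CLI namespace:

# Apply TP only when a tensor-parallel degree above 1 was requested.
if args.tp_degree > 1:
    tensor_parallelize(model, tp_pg)

# Enable strict DDP only in the pure data-parallel case (tp_degree == 1);
# with TP > 1 the parameters are already partitioned over the TP group.
use_strict_ddp = args.tp_degree == 1
model, optimizer = build_gemini(model, tp_pg, args.placement, use_strict_ddp)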