[example] gpt, shard init on all processes (#2366)

Author: Jiarui Fang
Date: 2023-01-06 15:44:50 +08:00
Committed by: GitHub
Parent: 1f8ab6f1f5
Commit: 1aaeb596c6

2 changed files with 18 additions and 12 deletions


@@ -148,10 +148,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """
     for mn, module in model.named_modules():
         for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by tow modules
+            # NOTE() a param may be shared by two modules
             if hasattr(param, 'visited'):
                 continue

+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
+            param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+
+            # shard it w.r.t. the TP pattern
             if 'mlp.c_fc' in mn:
                 if 'weight' in pn or 'bias' in pn:
                     split_param_col_tp1d(param, pg)    # column slice
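
The new replicate-then-shard pattern can be read on its own. Below is a minimal sketch, assuming the colossalai.tensor APIs this example already uses (ColoParameter, ProcessGroup, ReplicaSpec, ShardSpec); the function name is hypothetical and only mirrors what the loop above does for the column-sliced case:

from colossalai.tensor import ColoParameter, ProcessGroup, ReplicaSpec, ShardSpec

def replicate_then_shard_col(param: ColoParameter, pg: ProcessGroup):
    # Undo the shard-init layout: every rank holds a full replica again,
    # and the param is rebound from the init-time group to the TP group.
    param.set_dist_spec(ReplicaSpec())
    param.set_process_group(pg)
    # Apply the 1D tensor-parallel split along the last dim, i.e. the
    # "column slice" used for mlp.c_fc weights and biases.
    param.set_dist_spec(ShardSpec([-1], [pg.tp_world_size()]))

Converting to a replica first matters because the param may arrive sharded over a different group (the shard-init group spanning all ranks); re-replicating yields a layout from which any TP split is valid.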
@@ -170,7 +176,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                 split_param_col_tp1d(param, pg)    # column slice
             else:
                 param.set_dist_spec(ReplicaSpec())
-
             param.visited = True
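
The `visited` flag above guards shared parameters: with weight tying, the same tensor surfaces under two modules during the named_modules/named_parameters walk, so without the flag it would be converted and split twice. A self-contained illustration in plain PyTorch (toy modules, not from this example):

import torch

emb = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = emb.weight    # weight tying, as between GPT's wte and lm_head
model = torch.nn.ModuleDict({'emb': emb, 'head': head})

seen = set()
for mn, module in model.named_modules():
    for pn, param in module.named_parameters(recurse=False):
        # the tied tensor shows up twice; the second visit reports True
        print(f"{mn}.{pn} already seen: {id(param) in seen}")
        seen.add(id(param))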
@@ -248,27 +253,28 @@ def main():
     torch.manual_seed(123)
     if args.distplan == "colossalai":
         # all params must use the same process group.
-        default_pg = ProcessGroup(tp_degree=args.tp_degree)
-        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size)
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

         # build GPT model
         if version.parse(CAI_VERSION) > version.parse("0.1.10"):
             with ColoInitContext(device=get_current_device(),
                                  dtype=torch.half,
                                  default_dist_spec=default_dist_spec,
-                                 default_pg=default_pg):
+                                 default_pg=shard_pg):
                 model = model_builder(args.model_type)(checkpoint=True)
         else:
             with ColoInitContext(device=get_current_device()):
                 model = model_builder(args.model_type)(checkpoint=True)

-        pg = default_pg
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
-        tensor_parallelize(model, pg)
+        tensor_parallelize(model, tp_pg)

         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP. Note it must be used after TP
-        model, optimizer = build_gemini(model, pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement)

         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
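
Taken together, main() now builds two groups instead of one: shard_pg spans every rank and is only used at construction time (so with --shardinit no rank ever materializes a full fp16 copy of the model), while tp_pg carries the user-requested tp_degree and drives both tensor_parallelize and build_gemini. A sketch of the resulting layouts, with illustrative values (8 ranks, tp_degree 2):

import torch
from colossalai.tensor import ProcessGroup, ShardSpec

world_size = torch.distributed.get_world_size()    # e.g. 8
# Init-time group: each param is split 8 ways along its last dim, so a
# rank holds ~1/8 of every tensor while the model is being constructed.
shard_pg = ProcessGroup(tp_degree=world_size)
init_spec = ShardSpec([-1], [world_size])
# Compute-time group: the 2-way layout TP actually runs with; the
# replicate-then-shard step in tensor_parallelize moves each param onto it.
tp_pg = ProcessGroup(tp_degree=2)    # args.tp_degree in the script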