From 1aaeb596c63752071cbbaa2477a7f2406901b70b Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 6 Jan 2023 15:44:50 +0800
Subject: [PATCH] [example] gpt, shard init on all processes (#2366)

---
 colossalai/tensor/colo_tensor.py              |  8 +++----
 .../language/gpt/gemini/train_gpt_demo.py     | 22 ++++++++++++-------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 93ab982cc..3712d6a0a 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
     def set_process_group(self, pg: ProcessGroup):
         """set_process_group
         change the pg of the ColoTensor. Note that the valid use cases is limited.
-        Only existing pg is DP and dist spec is REPLICaTE is valid.
+        It only works when the target pg is DP-only or TP-only and the current dist spec of the Tensor is Replica.
 
         Args:
             pg (ProcessGroup): target pg
@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
         # if the new pg is the same as the old pg, just returns
         if self.process_group == pg:
             return
-        assert self.process_group.tp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group has both tp_world_size > 1 and dp_world_size > 1"
         assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
+            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"
 
         self.process_group = pg
 
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 14200bff7..29f8c8ef1 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -148,10 +148,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """
     for mn, module in model.named_modules():
         for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by tow modules
+            # NOTE() a param may be shared by two modules
             if hasattr(param, 'visited'):
                 continue
+
+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
             param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+
+            # shard it w.r.t tp pattern
             if 'mlp.c_fc' in mn:
                 if 'weight' in pn or 'bias' in pn:
                     split_param_col_tp1d(param, pg)    # colmn slice
@@ -170,7 +176,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                     split_param_col_tp1d(param, pg)    # colmn slice
                 else:
                     param.set_dist_spec(ReplicaSpec())
-
             param.visited = True
 
 
@@ -248,27 +253,28 @@ def main():
     torch.manual_seed(123)
     if args.distplan == "colossalai":
         # all param must use the same process group.
-        default_pg = ProcessGroup(tp_degree=args.tp_degree)
-        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size)
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
 
         # build GPT model
         if version.parse(CAI_VERSION) > version.parse("0.1.10"):
             with ColoInitContext(device=get_current_device(),
                                  dtype=torch.half,
                                  default_dist_spec=default_dist_spec,
-                                 default_pg=default_pg):
+                                 default_pg=shard_pg):
                 model = model_builder(args.model_type)(checkpoint=True)
         else:
             with ColoInitContext(device=get_current_device()):
                 model = model_builder(args.model_type)(checkpoint=True)
 
-        pg = default_pg
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
-        tensor_parallelize(model, pg)
+        tensor_parallelize(model, tp_pg)
 
         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement)
 
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
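
Below is a minimal sketch of the flow this patch switches the demo to: shard every parameter over all ranks at init time, then convert each one back to a replica, hand it to the real TP ProcessGroup, and split it along the TP pattern. It reuses only the APIs visible in train_gpt_demo.py (ColoInitContext, ProcessGroup, ShardSpec, ReplicaSpec, set_dist_spec/set_process_group/set_tensor_spec) and assumes the 0.2.x-era import paths from that script; the toy MLP, the uniform column split, and the build_tp_model helper name are illustrative stand-ins, not part of this patch.

import torch
import torch.nn as nn

import colossalai
from colossalai.tensor import (ColoParameter, ComputePattern, ComputeSpec, ProcessGroup,
                               ReplicaSpec, ShardSpec)
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def build_tp_model(tp_degree: int = 2):
    world_size = torch.distributed.get_world_size()

    # Step 1: shard init on all processes. Every parameter is created already
    # sliced along its last dim over ALL ranks, so no rank materializes the
    # full fp16 model during construction.
    shard_pg = ProcessGroup(tp_degree=world_size)
    default_dist_spec = ShardSpec([-1], [world_size])
    with ColoInitContext(device=get_current_device(),
                         dtype=torch.half,
                         default_dist_spec=default_dist_spec,
                         default_pg=shard_pg):
        # toy stand-in for the GPT model; dims must be divisible by world_size
        model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

    # Step 2: re-shard w.r.t. the real TP layout. set_process_group() is valid
    # here because shard_pg is TP-only (dp degree 1) and the spec is a replica again.
    tp_pg = ProcessGroup(tp_degree=tp_degree)
    for param in model.parameters():
        param.set_dist_spec(ReplicaSpec())
        param.set_process_group(tp_pg)
        # uniform column split for illustration; the demo instead chooses
        # col/row splits per layer via split_param_col_tp1d / split_param_row_tp1d
        param.set_tensor_spec(ShardSpec([-1], [tp_pg.tp_world_size()]),
                              ComputeSpec(ComputePattern.TP1D))
    return model, tp_pg


if __name__ == "__main__":
    # launched via torchrun, mirroring the demo's colossalai.launch_from_torch call
    colossalai.launch_from_torch(config={})
    model, tp_pg = build_tp_model(tp_degree=2)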