From f7fd592bf470a7190177af298709d2d59cca4596 Mon Sep 17 00:00:00 2001
From: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Date: Thu, 5 Jan 2023 17:57:50 +0800
Subject: [PATCH] [examples]adding tp to PaLM (#2319)

---
 examples/language/palm/train.py | 44 ++++++++++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 89b4e058f..7c080b7f3 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -104,6 +104,48 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model
 
+## Parameter Sharding Strategies for Tensor Parallelism
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(-1, param, pg)
+
+# Tensor Parallel
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+    """tensor_parallelize
+    Sharding the Model Parameters.
+    Args:
+        model (torch.nn.Module): a torch module to be sharded
+    """
+    for mn, module in model.named_modules():
+        for pn, param in module.named_parameters(recurse=False):
+            if hasattr(param, 'visited'):
+                continue
+            param.set_dist_spec(ReplicaSpec())
+            if 'net.0' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'to_q' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'to_kv' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            elif 'to_out' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            elif '1.1' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif '1.2' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            else:
+                param.set_dist_spec(ReplicaSpec())
+
+            param.visited = True
+
 args = parse_args()
 
 if args.distplan not in ["colossalai", "pytorch"]:
@@ -150,7 +192,7 @@ if args.distplan == "colossalai":
     model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
 
     pg = default_pg
-    #tensor_parallelize(model, pg)
+    tensor_parallelize(model, pg)
     model = gemini_zero_dpp(model, pg, args.placement)
 
     #optimizer
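
Reviewer note (illustrative only, not part of the patch): the "row slice" / "column slice" comments in tensor_parallelize refer to which dimension of a parameter is partitioned across the tensor-parallel ranks, i.e. ShardSpec([0], ...) versus ShardSpec([-1], ...). The snippet below is a minimal plain-PyTorch sketch of that partitioning for a hypothetical 4x8 weight and an assumed tensor-parallel group of 2 ranks; it does not use ColossalAI's ShardSpec/ComputeSpec API, and the shapes are made up for illustration.

    import torch

    # Hypothetical Linear weight with shape [out_features, in_features] = [4, 8].
    weight = torch.arange(32, dtype=torch.float32).reshape(4, 8)

    tp_world_size = 2  # assumed tensor-parallel group size

    # "Row slice" (dim 0): each rank keeps a contiguous block of rows -> shape [2, 8].
    row_shards = torch.chunk(weight, tp_world_size, dim=0)

    # "Column slice" (dim -1): each rank keeps a contiguous block of columns -> shape [4, 4].
    col_shards = torch.chunk(weight, tp_world_size, dim=-1)

    for rank in range(tp_world_size):
        print(f"rank {rank}: row shard {tuple(row_shards[rank].shape)}, "
              f"col shard {tuple(col_shards[rank].shape)}")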