modefied the pp build for ckpt adaptation (#803)

This commit is contained in:
LuGY
2022-04-24 12:23:16 +08:00
committed by GitHub
parent 8789850eea
commit c1e8d2001e
2 changed files with 6 additions and 3 deletions

View File

@@ -240,7 +240,6 @@ def build_pipeline_model_from_cfg(config,
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
"""An intializer to split the model into different stages for pipeline parallelism.
Note that `layer` must be `torch.nn.Sequential`.
Args:
layers (`torch.nn.Sequential`): Layers of model
num_chunks: The number of chunks you want to have on the current stage. This value should be 1
@@ -252,7 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
module_list = []
for start, end in partitions[pipeline_rank]:
module_list.append(nn.Sequential(*layers[start:end]))
module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
*layers[start:end],
*[nn.Identity() for _ in range(len(layers) - end)]))
if verbose:
logger = get_dist_logger()
logger.info(f'Total {len(layers)} layers', ranks=[0])
@@ -263,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
logger.info(log_str, ranks=[0])
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]