Optimize pipeline schedule (#94)

* add pipeline shared module wrapper and update load batch

* added model parallel process group for amp and clip grad (#86)

* added model parallel process group for amp and clip grad

* update amp and clip with model parallel process group

* remove pipeline_prev/next group (#88)

* micro batch offload

* optimize pipeline gpu memory usage

* pipeline can receive tensor shape (#93)

* optimize pipeline gpu memory usage

* fix grad accumulation step counter

* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
ver217
2021-12-30 15:56:46 +08:00
committed by GitHub
parent e5b9f9a08d
commit 96780e6ee4
29 changed files with 423 additions and 290 deletions

View File

@@ -33,6 +33,12 @@ def check_pipeline_parallel_rank(rank):
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
def check_model_parallel_rank(rank):
for i in range(8):
if rank in [i, i+8]:
assert gpc.get_local_rank(ParallelMode.MODEL) == i
def check_tensor_parallel_rank(rank):
if rank in [0, 4, 8, 12]:
assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
@@ -75,6 +81,7 @@ def init_2d(rank, world_size, backend, port, host):
check_data_parallel_rank(rank)
check_2d_parallel_rank(rank)
check_pipeline_parallel_rank(rank)
check_model_parallel_rank(rank)
gpc.destroy()
torch.cuda.empty_cache()

View File

@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
assert ppr == 1
def check_model_parallel_rank(rank):
for i in range(16):
if rank in [i, i+16]:
assert gpc.get_local_rank(ParallelMode.MODEL) == i
def check_tensor_parallel_rank(rank):
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
@@ -98,6 +104,7 @@ def init_2halfd(rank, world_size, backend, port, host):
check_pipeline_parallel_rank(rank)
check_tensor_parallel_rank(rank)
check_2p5d_parallel_rank(rank)
check_model_parallel_rank(rank)
gpc.destroy()
torch.cuda.empty_cache()

View File

@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
assert ppr == 1
def check_model_parallel_rank(rank):
for i in range(16):
if rank in [i, i+16]:
assert gpc.get_local_rank(ParallelMode.MODEL) == i
def check_tensor_parallel_rank(rank):
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
@@ -90,6 +96,7 @@ def init_3d(rank, world_size, backend, port, host):
check_3d_parallel_rank(rank)
check_data_parallel_rank(rank)
check_pipeline_parallel_rank(rank)
check_model_parallel_rank(rank)
gpc.destroy()
torch.cuda.empty_cache()

View File

@@ -23,7 +23,7 @@ BATCH_SIZE = 16
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.TORCH),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)

View File

@@ -75,40 +75,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
rank, check_equal(grad, output_grad)))
def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
dtype = torch.float32
device = get_current_device()
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
# recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
dist.all_reduce(tensor)
grad = torch.randn(grad_shape, dtype=dtype, device=device)
dist.all_reduce(grad)
if rank % 2 == 0:
need_meta = True
need_meta = send_tensor_meta(tensor, need_meta)
logger.info('Rank {} shape sent (need meta: {}).'.format(
rank, need_meta))
req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
req.wait()
out = tensor.clone()
logger.info('Rank {} test op: tensor sent.'.format(rank))
else:
recv_tensor_shape = recv_tensor_meta(None)
logger.info('Rank {} shape received. Correct shape: {}'.format(
rank, tensor_shape == recv_tensor_shape))
out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
req.wait()
logger.info('Rank {} test op: received tensor ({})'.format(
rank, out.shape))
logger.info('Rank {} test op. Correct tensor: {}'.format(
rank, check_equal(tensor, out)))
def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
def check_comm(size, rank, prev_rank, next_rank, logger):
dtype = torch.float32
device = get_current_device()
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
@@ -117,7 +84,6 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
dist.all_reduce(tensor)
grad = torch.randn(grad_shape, dtype=dtype, device=device)
dist.all_reduce(grad)
check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
check_forward(tensor, rank, logger)
check_backward(grad, rank, logger)
check_forward_backward(tensor, grad, rank, logger)
@@ -135,18 +101,13 @@ def run_check(rank, world_size, port):
logger = get_dist_logger()
rank = gpc.get_global_rank()
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
logger.info(
'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
rank, prev_rank, up_ranks, next_rank, down_ranks))
'Rank {0}: prev rank {1}, next rank {2}'.format(
rank, prev_rank, next_rank))
logger.info('Distributed environment is initialzied.')
check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
logger)
check_comm(world_size, rank, prev_rank, next_rank, logger)
gpc.destroy()
torch.cuda.empty_cache()