Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
Optimize pipeline schedule (#94)
* add pipeline shared module wrapper and update load batch
* added model parallel process group for amp and clip grad (#86)
* added model parallel process group for amp and clip grad
* update amp and clip with model parallel process group
* remove pipeline_prev/next group (#88)
* micro batch offload
* optimize pipeline gpu memory usage
* pipeline can receive tensor shape (#93)
* optimize pipeline gpu memory usage
* fix grad accumulation step counter
* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
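To make the "model parallel process group for amp and clip grad" item concrete: when a model is split over several ranks, each rank holds only a shard of the gradients, so the gradient norm must be reduced across the model-parallel group before clipping. The sketch below is an independent illustration using plain torch.distributed, not ColossalAI's implementation; the group handle mp_group and the helper name are assumptions.

import torch
import torch.distributed as dist


def clip_grad_norm_model_parallel(parameters, max_norm, mp_group=None):
    # Sum of squared gradient norms over this rank's shard of the model.
    params = [p for p in parameters if p.grad is not None]
    total_norm_sq = torch.zeros(1, device=params[0].grad.device)
    for p in params:
        total_norm_sq += p.grad.detach().float().norm(2) ** 2
    # Reduce over the model-parallel group so every rank agrees on the norm.
    # (A real implementation must avoid double-counting gradients that are
    # replicated rather than sharded across the group.)
    dist.all_reduce(total_norm_sq, op=dist.ReduceOp.SUM, group=mp_group)
    total_norm = total_norm_sq.sqrt().item()
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            p.grad.detach().mul_(clip_coef)
    return total_norm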
@@ -33,6 +33,12 @@ def check_pipeline_parallel_rank(rank):
         assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1


+def check_model_parallel_rank(rank):
+    for i in range(8):
+        if rank in [i, i+8]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     if rank in [0, 4, 8, 12]:
         assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
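For reference, the new check_model_parallel_rank above encodes the layout of a 16-process 2D run: the MODEL group covers the 8 tensor-plus-pipeline ranks of one data-parallel replica, so ranks r and r + 8 share the same local MODEL rank. A stand-alone sanity check of that arithmetic (the world size and contiguous grouping are assumptions read off the asserts; no distributed setup needed):

WORLD_SIZE = 16        # implied by the rank lists in the asserts above
MODEL_GROUP_SIZE = 8   # pipeline size x tensor size in the 2D test

for rank in range(WORLD_SIZE):
    local_model_rank = rank % MODEL_GROUP_SIZE
    expected = rank if rank < MODEL_GROUP_SIZE else rank - MODEL_GROUP_SIZE
    # Mirrors `if rank in [i, i + 8]: assert ... == i` from the diff.
    assert local_model_rank == expected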
@@ -75,6 +81,7 @@ def init_2d(rank, world_size, backend, port, host):
     check_data_parallel_rank(rank)
     check_2d_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()

@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
         assert ppr == 1


+def check_model_parallel_rank(rank):
+    for i in range(16):
+        if rank in [i, i+16]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)

@@ -98,6 +104,7 @@ def init_2halfd(rank, world_size, backend, port, host):
     check_pipeline_parallel_rank(rank)
     check_tensor_parallel_rank(rank)
     check_2p5d_parallel_rank(rank)
+    check_model_parallel_rank(rank)
    gpc.destroy()
     torch.cuda.empty_cache()

@@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
         assert ppr == 1


+def check_model_parallel_rank(rank):
+    for i in range(16):
+        if rank in [i, i+16]:
+            assert gpc.get_local_rank(ParallelMode.MODEL) == i
+
+
 def check_tensor_parallel_rank(rank):
     tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)

@@ -90,6 +96,7 @@ def init_3d(rank, world_size, backend, port, host):
     check_3d_parallel_rank(rank)
     check_data_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
     gpc.destroy()
     torch.cuda.empty_cache()

@@ -23,7 +23,7 @@ BATCH_SIZE = 16
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
 CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.TORCH),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
               gradient_accumulation=2)

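The one-line change above switches the test's mixed-precision mode from torch.cuda.amp (AMP_TYPE.TORCH) to ColossalAI's naive fp16 optimizer (AMP_TYPE.NAIVE), presumably because that is the variant compatible with the pipeline schedule; everything else in the config stays the same. Reassembled as stand-alone Python for readability (the import path is an assumption based on ColossalAI of that era):

from colossalai.amp import AMP_TYPE

BATCH_SIZE = 16
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
              fp16=dict(mode=AMP_TYPE.NAIVE),
              gradient_accumulation=2)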
@@ -75,40 +75,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
         rank, check_equal(grad, output_grad)))


-def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
-    dtype = torch.float32
-    device = get_current_device()
-    tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    # recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
-    dist.all_reduce(tensor)
-    grad = torch.randn(grad_shape, dtype=dtype, device=device)
-    dist.all_reduce(grad)
-    if rank % 2 == 0:
-        need_meta = True
-        need_meta = send_tensor_meta(tensor, need_meta)
-        logger.info('Rank {} shape sent (need meta: {}).'.format(
-            rank, need_meta))
-        req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
-        req.wait()
-        out = tensor.clone()
-        logger.info('Rank {} test op: tensor sent.'.format(rank))
-    else:
-        recv_tensor_shape = recv_tensor_meta(None)
-        logger.info('Rank {} shape received. Correct shape: {}'.format(
-            rank, tensor_shape == recv_tensor_shape))
-        out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
-        req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
-        req.wait()
-        logger.info('Rank {} test op: received tensor ({})'.format(
-            rank, out.shape))
-
-    logger.info('Rank {} test op. Correct tensor: {}'.format(
-        rank, check_equal(tensor, out)))
-
-
-def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
+def check_comm(size, rank, prev_rank, next_rank, logger):
     dtype = torch.float32
     device = get_current_device()
     tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
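The deleted check_op exercised the old handshake: the sending stage pushed the tensor's shape with send_tensor_meta, the receiving stage recovered it with recv_tensor_meta, and only then was the payload broadcast over the now-removed PIPELINE_PREV/PIPELINE_NEXT groups. Since the schedule can now be given the tensor shape up front ("pipeline can receive tensor shape"), that round trip becomes optional. Below is a generic point-to-point version of such a metadata handshake written against plain torch.distributed; the function names and dtype handling are assumptions, not ColossalAI's API.

import torch
import torch.distributed as dist


def send_with_meta(tensor, dst):
    # Send the number of dimensions, then the sizes, then the payload itself.
    ndim = torch.tensor([tensor.dim()], dtype=torch.long, device=tensor.device)
    shape = torch.tensor(tensor.shape, dtype=torch.long, device=tensor.device)
    dist.send(ndim, dst=dst)
    dist.send(shape, dst=dst)
    dist.send(tensor, dst=dst)


def recv_with_meta(src, dtype, device):
    # Recover the shape first so the receive buffer can be allocated.
    ndim = torch.empty(1, dtype=torch.long, device=device)
    dist.recv(ndim, src=src)
    shape = torch.empty(int(ndim.item()), dtype=torch.long, device=device)
    dist.recv(shape, src=src)
    out = torch.empty(tuple(shape.tolist()), dtype=dtype, device=device)
    dist.recv(out, src=src)
    return out

If the receiver already knows the shape, the two metadata messages can simply be skipped, which is the saving that passing the shape to the schedule enables.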
@@ -117,7 +84,6 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
     dist.all_reduce(tensor)
     grad = torch.randn(grad_shape, dtype=dtype, device=device)
     dist.all_reduce(grad)
-    check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
     check_forward(tensor, rank, logger)
     check_backward(grad, rank, logger)
     check_forward_backward(tensor, grad, rank, logger)
@@ -135,18 +101,13 @@ def run_check(rank, world_size, port):
     logger = get_dist_logger()
     rank = gpc.get_global_rank()
     prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-    up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
-    up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
     next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-    down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
-    down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
     logger.info(
-        'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
-            rank, prev_rank, up_ranks, next_rank, down_ranks))
+        'Rank {0}: prev rank {1}, next rank {2}'.format(
+            rank, prev_rank, next_rank))
     logger.info('Distributed environment is initialzied.')

-    check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
-               logger)
+    check_comm(world_size, rank, prev_rank, next_rank, logger)
     gpc.destroy()
     torch.cuda.empty_cache()
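run_check(rank, world_size, port) is the per-process entry point of this test; suites like this are usually driven by torch.multiprocessing.spawn, roughly as sketched below. The process count and the partial-based wrapper are illustrative assumptions, not the repository's actual launcher.

from functools import partial

import torch.multiprocessing as mp

# run_check is the per-process function shown in the diff above.


def main():
    world_size = 8  # assumed number of test processes; not taken from the diff
    run_func = partial(run_check, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    main()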