mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 02:20:49 +00:00
Optimize pipeline schedule (#94)
* add pipeline shared module wrapper and update load batch * added model parallel process group for amp and clip grad (#86) * added model parallel process group for amp and clip grad * update amp and clip with model parallel process group * remove pipeline_prev/next group (#88) * micro batch offload * optimize pipeline gpu memory usage * pipeline can receive tensor shape (#93) * optimize pipeline gpu memory usage * fix grad accumulation step counter * rename classes and functions Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
@@ -75,40 +75,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
|
||||
rank, check_equal(grad, output_grad)))
|
||||
|
||||
|
||||
def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
||||
dtype = torch.float32
|
||||
device = get_current_device()
|
||||
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
# recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(tensor)
|
||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(grad)
|
||||
if rank % 2 == 0:
|
||||
need_meta = True
|
||||
need_meta = send_tensor_meta(tensor, need_meta)
|
||||
logger.info('Rank {} shape sent (need meta: {}).'.format(
|
||||
rank, need_meta))
|
||||
req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
|
||||
req.wait()
|
||||
out = tensor.clone()
|
||||
logger.info('Rank {} test op: tensor sent.'.format(rank))
|
||||
else:
|
||||
recv_tensor_shape = recv_tensor_meta(None)
|
||||
logger.info('Rank {} shape received. Correct shape: {}'.format(
|
||||
rank, tensor_shape == recv_tensor_shape))
|
||||
out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
|
||||
req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
|
||||
req.wait()
|
||||
logger.info('Rank {} test op: received tensor ({})'.format(
|
||||
rank, out.shape))
|
||||
|
||||
logger.info('Rank {} test op. Correct tensor: {}'.format(
|
||||
rank, check_equal(tensor, out)))
|
||||
|
||||
|
||||
def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
||||
def check_comm(size, rank, prev_rank, next_rank, logger):
|
||||
dtype = torch.float32
|
||||
device = get_current_device()
|
||||
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
@@ -117,7 +84,6 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
||||
dist.all_reduce(tensor)
|
||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
dist.all_reduce(grad)
|
||||
check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
|
||||
check_forward(tensor, rank, logger)
|
||||
check_backward(grad, rank, logger)
|
||||
check_forward_backward(tensor, grad, rank, logger)
|
||||
@@ -135,18 +101,13 @@ def run_check(rank, world_size, port):
|
||||
logger = get_dist_logger()
|
||||
rank = gpc.get_global_rank()
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
|
||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
|
||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
||||
logger.info(
|
||||
'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
|
||||
rank, prev_rank, up_ranks, next_rank, down_ranks))
|
||||
'Rank {0}: prev rank {1}, next rank {2}'.format(
|
||||
rank, prev_rank, next_rank))
|
||||
logger.info('Distributed environment is initialzied.')
|
||||
|
||||
check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
|
||||
logger)
|
||||
check_comm(world_size, rank, prev_rank, next_rank, logger)
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
Reference in New Issue
Block a user