Optimize pipeline schedule (#94)

* add pipeline shared module wrapper and update load batch

* added model parallel process group for amp and clip grad (#86)

* added model parallel process group for amp and clip grad

* update amp and clip with model parallel process group

* remove pipeline_prev/next group (#88)

* micro batch offload

* optimize pipeline gpu memory usage

* pipeline can receive tensor shape (#93)

* optimize pipeline gpu memory usage

* fix grad accumulation step counter

* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
ver217
2021-12-30 15:56:46 +08:00
committed by GitHub
parent e5b9f9a08d
commit 96780e6ee4
29 changed files with 423 additions and 290 deletions

View File

@@ -155,22 +155,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in params)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
ops = []
# Take max across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
ops.append(dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(
ParallelMode.TENSOR),
async_op=True))
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
ops.append(dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(
ParallelMode.PIPELINE),
async_op=True))
for req in ops:
req.wait()
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
async_op=False)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []

View File

@@ -65,6 +65,7 @@ class GradAccumOptimizer(ColossalaiOptimizer):
self.optim.backward(scaled_loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
self.accumulate_step += 1
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
if no_sync:
@@ -81,7 +82,7 @@ class GradAccumDataloader():
be update only twice at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle.
Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader,
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
:param dataloader: your dataloader object
:type dataloader: Iterable
:param accumulate_size: the number of steps to accumulate gradients