mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 19:55:03 +00:00
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699
. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> * Split conv2d, class token, positional embedding in 2d, Fix random number in ddp Fix convergence in cifar10, Imagenet1000 * Integrate 1d tensor parallel in Colossal-AI (#39) * fixed 1D and 2D convergence (#38) * optimized 2D operations * fixed 1D ViT convergence problem * Feature/ddp (#49) * remove redundancy func in setup (#19) (#20) * use env to control the language of doc (#24) (#25) * Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699
. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29) * add explanation for ViT example (#35) (#36) * support torch ddp * fix loss accumulation * add log for ddp * change seed * modify timing hook Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * Feature/pipeline (#40) * remove redundancy func in setup (#19) (#20) * use env to control the language of doc (#24) (#25) * Support TP-compatible Torch AMP and Update trainer API (#27) * Add gradient accumulation, fix lr scheduler * fix FP16 optimizer and adapted torch amp with tensor parallel (#18) * fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes * fixed trainer * Revert "fixed trainer" This reverts commit2e0b0b7699
. * improved consistency between trainer, engine and schedule (#23) Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29) * add explanation for ViT example (#35) (#36) * optimize communication of pipeline parallel * fix grad clip for pipeline Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51) * Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset * update api for better usability (#58) update api for better usability Co-authored-by: 1SAA <c2h214748@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com> Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, dim: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
parallel_mode: ParallelMode, async_op=False) -> Tensor:
|
||||
"""Gathers all tensors from the parallel group and concatenates them in a
|
||||
specific dimension.
|
||||
|
||||
@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = tensor.clone()
|
||||
shape = list(temp.shape)
|
||||
shape[dim] *= depth
|
||||
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
|
||||
out = list(torch.chunk(out, depth, dim=dim))
|
||||
out = [val.contiguous() for val in out]
|
||||
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
|
||||
out = torch.cat(out, dim=dim)
|
||||
return out
|
||||
# shape = list(temp.shape)
|
||||
# shape[dim] *= depth
|
||||
# out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
|
||||
# out = list(torch.chunk(out, depth, dim=dim))
|
||||
# out = [val.contiguous() for val in out]
|
||||
shape = [1] * len(tensor.shape)
|
||||
shape[dim] = depth
|
||||
out = tensor.repeat(shape)
|
||||
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
|
||||
op = dist.all_gather(tensor_list=out,
|
||||
tensor=temp,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
# out = torch.cat(out, dim=dim)
|
||||
if async_op:
|
||||
return out, op
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def reduce_scatter(tensor: Tensor, dim: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
parallel_mode: ParallelMode, async_op=False) -> Tensor:
|
||||
"""Reduces all tensors then scatters it in a specific dimension to all
|
||||
members in the parallel group.
|
||||
|
||||
@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
|
||||
:rtype: Tensor
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = list(torch.chunk(tensor, depth, dim=dim))
|
||||
temp = [val.contiguous() for val in temp]
|
||||
out = torch.empty(temp[0].shape,
|
||||
dtype=temp[0].dtype,
|
||||
device=get_current_device())
|
||||
dist.reduce_scatter(output=out,
|
||||
input_list=temp,
|
||||
group=gpc.get_group(parallel_mode))
|
||||
return out
|
||||
# temp = list(torch.chunk(tensor, depth, dim=dim))
|
||||
# temp = [val.contiguous() for val in temp]
|
||||
# out = torch.zeros(temp[0].shape,
|
||||
# dtype=temp[0].dtype,
|
||||
# device=get_current_device())
|
||||
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
|
||||
out = temp[0].clone()
|
||||
op = dist.reduce_scatter(output=out,
|
||||
input_list=temp,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
if async_op:
|
||||
return out, op
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def scatter(tensor: Tensor, src: int, dim: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
"""Scatters in a specific dimension from source rank to all ranks in
|
||||
the parallel group.
|
||||
def all_reduce(tensor: Tensor,
|
||||
parallel_mode: ParallelMode,
|
||||
async_op=False) -> Tensor:
|
||||
op = dist.all_reduce(tensor,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
if async_op:
|
||||
return tensor, op
|
||||
else:
|
||||
return tensor
|
||||
|
||||
|
||||
# def scatter(tensor: Tensor, src: int, dim: int,
|
||||
# parallel_mode: ParallelMode) -> Tensor:
|
||||
# """Scatters in a specific dimension from source rank to all ranks in
|
||||
# the parallel group.
|
||||
|
||||
:param tensor: Tensor to be scattered
|
||||
:param dim: The dimension scattering in
|
||||
:param parallel_mode: Parallel group mode used in this communication
|
||||
:type tensor: Tensor
|
||||
:type dim: int
|
||||
:type parallel_mode: ParallelMode
|
||||
:return: The tensor generated by scatter
|
||||
:rtype: Tensor
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = tensor.clone()
|
||||
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
|
||||
return out
|
||||
# :param tensor: Tensor to be scattered
|
||||
# :param dim: The dimension scattering in
|
||||
# :param parallel_mode: Parallel group mode used in this communication
|
||||
# :type tensor: Tensor
|
||||
# :type dim: int
|
||||
# :type parallel_mode: ParallelMode
|
||||
# :return: The tensor generated by scatter
|
||||
# :rtype: Tensor
|
||||
# """
|
||||
# depth = gpc.get_world_size(parallel_mode)
|
||||
# temp = tensor.clone()
|
||||
# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
|
||||
# rank = gpc.get_local_rank(parallel_mode)
|
||||
# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
|
||||
# return out
|
||||
|
Reference in New Issue
Block a user