moved env variables to global variables (#215);

added branch context;
added vocab parallel layers (sketched below);
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed a few collective communication bugs
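
For context on the "vocab parallel layers" above, here is a minimal single-process sketch of the Megatron-LM-style idea such layers follow. The class name, the `rank`/`depth` arguments, and the commented-out all-reduce are all illustrative, not the library's actual API: each rank stores one vocabulary shard, embeds only the ids it owns, and the partial results are summed across the tensor-parallel group.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VocabParallelEmbeddingSketch(nn.Module):
    """One shard of an embedding whose vocabulary is split across ranks."""

    def __init__(self, vocab_size: int, embed_dim: int, rank: int, depth: int):
        super().__init__()
        part = vocab_size // depth                      # assume an even split
        self.start, self.end = rank * part, (rank + 1) * part
        self.weight = nn.Parameter(torch.randn(part, embed_dim))

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        in_shard = (ids >= self.start) & (ids < self.end)
        # Clamp out-of-shard ids to a valid local index; their rows are
        # zeroed below and contributed by whichever rank owns them.
        local_ids = (ids - self.start).clamp(0, self.weight.size(0) - 1)
        out = F.embedding(local_ids, self.weight)
        out = out * in_shard.unsqueeze(-1)              # zero foreign rows
        # The real layer would all-reduce `out` across the tensor-parallel
        # group here to sum the partial embeddings from every shard.
        return out

ids = torch.randint(0, 100, (2, 5))
shard = VocabParallelEmbeddingSketch(vocab_size=100, embed_dim=8, rank=0, depth=4)
print(shard(ids).shape)                                 # torch.Size([2, 5, 8])
```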
commit 9ee197d0e9 (parent b82d60be02)
Author: アマデウス, 2022-02-14 11:15:02 +08:00
Committed by: Frank Lee
63 changed files with 4304 additions and 1040 deletions


@@ -30,7 +30,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
     """
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
-        out = [tensor]
+        out = tensor
         work = None
     else:
         shape = list(tensor.shape)
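
This hunk fixes the degenerate single-rank path of `all_gather`: the multi-rank branch concatenates the gathered shards with `torch.cat` and returns a plain `Tensor`, so returning `[tensor]` at `depth == 1` handed callers a list instead. A minimal single-process sketch of the corrected contract (the helper name and the faked shard gather are hypothetical; the real code calls `dist.all_gather` on the process group):

```python
import torch

def all_gather_sketch(tensor: torch.Tensor, dim: int, depth: int) -> torch.Tensor:
    if depth == 1:
        # Nothing to gather: return the tensor itself, not [tensor], so the
        # type matches the torch.cat result of the multi-rank branch below.
        return tensor
    # Stand-in for dist.all_gather: every "rank" holds the same shard here,
    # which is enough to show the shape and type of the result.
    shards = [tensor.clone() for _ in range(depth)]
    return torch.cat(shards, dim=dim).contiguous()

x = torch.randn(2, 3)
assert torch.equal(all_gather_sketch(x, dim=0, depth=1), x)  # a Tensor, not [x]
assert all_gather_sketch(x, dim=0, depth=4).shape == (8, 3)  # concatenated
```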
@@ -96,34 +96,40 @@ def all_reduce(tensor: Tensor,
                 async_op: bool = False) -> Tensor:
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
         out = tensor
         work = None
     else:
-        work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out


 def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
         out = tensor
         work = None
     else:
-        work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out


 def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
         out = tensor
         work = None
     else:
-        work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out
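
The `all_reduce`, `broadcast`, and `reduce` hunks share one fix. `Tensor.contiguous()` returns a copy whenever the input is not already contiguous, so the old code handed an unnamed copy to the in-place collective and then returned the original `tensor`, silently discarding the communicated values for non-contiguous inputs. Binding the copy to `out`, passing `out` to the collective, and returning `out` closes that hole. A minimal single-process sketch of the failure mode, with an in-place `add_` standing in for the collective write:

```python
import torch

t = torch.randn(4, 4).t()            # a transposed view is non-contiguous
assert not t.is_contiguous()
assert t.contiguous() is not t       # contiguous() allocates a fresh copy here

# Old pattern: the in-place write lands in a throwaway temporary.
t.contiguous().add_(1.0)             # stand-in for dist.all_reduce(tensor.contiguous(), ...)
# `t` is unchanged; the "reduced" values are gone.

# Fixed pattern: name the copy, mutate it, return it.
out = t.contiguous()
out.add_(1.0)                        # stand-in for dist.all_reduce(out, ...)
assert torch.allclose(out, t + 1.0)  # callers who receive `out` see the result

# Note: for an already-contiguous tensor, contiguous() returns the tensor
# itself, which is why the old code only misbehaved on non-contiguous inputs.
```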