improved allgather & reducescatter for 3d

This commit is contained in:
zbian
2023-01-03 15:26:47 +08:00
committed by アマデウス
parent c719798abe
commit e94c79f15b
4 changed files with 43 additions and 29 deletions

View File

@@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
ctx.output_parallel_mode = output_parallel_mode
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, -1, weight_parallel_mode)
weight = all_gather(weight, 0, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
@@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
input_op.wait()
@@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
ctx.weight_id = weight_id
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)

View File

@@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.skip_bias_add = skip_bias_add
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.in_features_per_partition = divide(in_features, self.depth**2)
self.out_features_per_partition = divide(out_features, self.depth)
self.bias_features_per_partition = divide(out_features, self.depth)
self.weight = Parameter(
@@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
weight_key: 0,
bias_key: 0
},
partition_states={
@@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
weight_key: 0,
bias_key: 0
},
partition_states={