mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
improved allgather & reducescatter for 3d
This commit is contained in:
@@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
|
||||
ctx.output_parallel_mode = output_parallel_mode
|
||||
|
||||
input_ = all_gather(input_, 0, input_parallel_mode)
|
||||
weight = all_gather(weight, -1, weight_parallel_mode)
|
||||
weight = all_gather(weight, 0, weight_parallel_mode)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
@@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):
|
||||
|
||||
weight_grad = torch.matmul(
|
||||
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
|
||||
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
|
||||
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
|
||||
|
||||
input_op.wait()
|
||||
@@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
|
||||
ctx.weight_id = weight_id
|
||||
|
||||
input_ = all_gather(input_, 0, input_parallel_mode)
|
||||
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
|
||||
weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
|
@@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
|
||||
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
|
||||
self.depth = get_depth_from_env()
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.in_features_per_partition = divide(in_features, self.depth)
|
||||
self.out_features_per_partition = divide(out_features, self.depth**2)
|
||||
self.in_features_per_partition = divide(in_features, self.depth**2)
|
||||
self.out_features_per_partition = divide(out_features, self.depth)
|
||||
self.bias_features_per_partition = divide(out_features, self.depth)
|
||||
|
||||
self.weight = Parameter(
|
||||
@@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
|
||||
local_state,
|
||||
self.weight_parallel_mode,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
@@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
|
||||
local_state,
|
||||
self.weight_parallel_mode,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
|
Reference in New Issue
Block a user