mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
remove gather out in parallel action (#1163)
This commit is contained in:
@@ -41,6 +41,7 @@ def run_with_spec(spec_init_func):
|
||||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.linear(x, weight, bias)
|
||||
colo_out = colo_out.to_replicate()
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
|
@@ -26,10 +26,10 @@ def init_1d_row_linear(weight):
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight, gather_out=True):
|
||||
def init_1d_col_linear(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
@@ -98,7 +98,7 @@ def run_1d_hybrid_tp(model_name):
|
||||
if 'proj2' in name and 'weight' in name:
|
||||
init_1d_row_linear(p)
|
||||
if 'classifier' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p, gather_out=False)
|
||||
init_1d_col_linear(p)
|
||||
|
||||
model = model.cuda()
|
||||
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
|
||||
|
Reference in New Issue
Block a user