remove gather out in parallel action (#1163)

Jiarui Fang
2022-06-23 16:35:05 +08:00
committed by GitHub
parent 51f1ec96b0
commit 177c374401
8 changed files with 43 additions and 32 deletions
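In short: this change drops the gather_out flag from ParallelAction, so a TP1D column-parallel linear no longer gathers its output implicitly; callers convert the sharded result themselves via to_replicate(). A minimal before/after sketch assembled from the diffs below (the spec construction is taken from the diff itself; treat the surrounding setup and import paths as assumptions, not the exact ColossalAI API):

    # Before (removed API): gathering was baked into the parallel action.
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D, gather_out=True))

    # After: the action only names the compute pattern...
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))

    # ...and the caller gathers explicitly when the full tensor is needed.
    colo_out = F.linear(x, weight, bias)
    colo_out = colo_out.to_replicate()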


@@ -41,6 +41,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
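The assert above compares against a plain torch result, so the TP1D column-sharded output must be gathered first; that is what to_replicate() provides here. Conceptually this is an all-gather along the sharded (last) dimension. A self-contained sketch of that semantics with raw torch.distributed, assuming an initialized process group (gather_last_dim is a hypothetical helper, not ColossalAI API):

    import torch
    import torch.distributed as dist

    def gather_last_dim(shard: torch.Tensor) -> torch.Tensor:
        # Conceptual stand-in for to_replicate() on a TP1D column-sharded
        # output: collect every rank's shard, then concatenate on the last dim.
        world_size = dist.get_world_size()
        shards = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(shards, shard.contiguous())
        return torch.cat(shards, dim=-1)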


@@ -26,10 +26,10 @@ def init_1d_row_linear(weight):
         weight.set_spec(spec)
 
 
-def init_1d_col_linear(weight, gather_out=True):
+def init_1d_col_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
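With the flag gone, init_1d_col_linear matches init_1d_row_linear in signature; the two differ only in which dimension of the weight they shard. A short usage sketch under the same assumptions as these tests, i.e. gpc already initialized with a 1D tensor-parallel group and the parameters already wrapped as ColoTensors (linear and proj are hypothetical modules):

    init_1d_col_linear(linear.weight)   # no gather_out argument anymore
    init_1d_row_linear(proj.weight)     # row variant, unchanged by this commit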
@@ -98,7 +98,7 @@ def run_1d_hybrid_tp(model_name):
         if 'proj2' in name and 'weight' in name:
             init_1d_row_linear(p)
         if 'classifier' in name and ('weight' in name or 'bias' in name):
-            init_1d_col_linear(p, gather_out=False)
+            init_1d_col_linear(p)
 
     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
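One consequence for this hybrid-TP test: without gather_out=False, the classifier's column-sharded output now behaves like any other TP1D output, and whichever consumer needs the full activation gathers it explicitly. A hedged training-step sketch in that spirit (data, target, and criterion are hypothetical placeholders, not part of this diff):

    out = model(data)         # output may still be TP1D-sharded
    out = out.to_replicate()  # explicit gather replaces the old gather_out flag
    loss = criterion(out, target)
    loss.backward()
    colo_optimizer.step()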