[tensor] derive compute pattern from dist spec (#971)

* derive compute pattern from dist spec

* polish code
This commit is contained in:
ver217
2022-05-16 14:58:08 +08:00
committed by GitHub
parent 46bc95708f
commit c2fdc6a011
10 changed files with 79 additions and 65 deletions

View File

@@ -40,7 +40,7 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -55,7 +55,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -17,7 +17,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
def init_1d_row(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -31,7 +31,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight):
def init_1d_col(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)

View File

@@ -18,7 +18,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
def init_1d_row(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -33,7 +33,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -86,35 +86,43 @@ def set_seed(seed):
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def init_1d_row_linear(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_col_linear(weight, gather_out=True):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
gather_out=gather_out)])
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1D,
parallel_mode=ParallelMode.PARALLEL_1D,
gather_out=gather_out)
])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_row_embedding(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_col_embedding(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
def run_1d_hybrid_tp(model_name):
# A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -124,7 +132,7 @@ def run_1d_hybrid_tp(model_name):
set_seed(1)
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
@@ -173,7 +181,7 @@ def run_1d_hybrid_tp(model_name):
if rank == 0:
model_torch.eval()
colo_optimizer_torch.zero_grad()
data = data.to(get_current_device())
label = label.to(get_current_device())
@@ -217,11 +225,11 @@ def run_1d_hybrid_tp(model_name):
assert torch.allclose(p1, p2)
else:
# TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
if p1.size(-1) < p2.size(-1): # col
if p1.size(-1) < p2.size(-1): # col
world_size = p2.size(-1) // p1.size(-1)
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
elif p1.size(0) < p2.size(0): # row
elif p1.size(0) < p2.size(0): # row
world_size = p2.size(0) // p1.size(0)
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
@@ -376,7 +384,7 @@ def _run_pretrain_load():
if isinstance(param, ColoParameter):
c1 += 1
else:
c2 +=1
c2 += 1
dict_col[name] = param
assert c_ref == c1
assert c2 == 0
@@ -395,6 +403,7 @@ def run_model_dist(rank, world_size, port):
for name in ['bert', 'simple_net']:
run_1d_hybrid_tp(name)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
# @parameterize('world_size', [1, 4])