[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)

Jiarui Fang
2022-06-24 13:08:54 +08:00
committed by GitHub
parent f4ef224358
commit 4b9bba8116
23 changed files with 116 additions and 105 deletions
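The rename is mechanical: the setter `set_spec` becomes `set_tensor_spec` and the `spec` property becomes `tensor_spec` at every call site below. The other half of the title, `output_replicate`, is a new field on `ComputeSpec` that the first test exercises through `to_replicate()`. A minimal sketch of such a spec follows; the field name comes from the commit title, while the constructor shape and the default value are assumptions not shown in this diff:

from dataclasses import dataclass
from enum import Enum

class ComputePattern(Enum):
    TP1D = 1  # stub; the real enum lives in colossalai

@dataclass
class ComputeSpec:
    compute_pattern: ComputePattern
    # New in this commit per the title; defaulting it to True is an
    # assumption, since the field's definition is not part of this diff.
    output_replicate: bool = True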

View File

@@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def init_1d_col(weight, bias):
@@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)


 def run_with_spec(spec_init_func):
@@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
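The added `to_replicate()` line is the behavioral change: under 1D tensor parallelism the `addmm` output is sharded across ranks, so it has to be gathered back to a replicated layout before it can be compared elementwise with the single-process reference `out`. In plain `torch.distributed` terms the conversion amounts to an all-gather along the shard dimension; a sketch under that assumption (`to_replicate`'s real implementation is not part of this diff):

import torch
import torch.distributed as dist

def gather_to_replicate(local_shard: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # All-gather the per-rank shards and concatenate them along the shard
    # dimension so that every rank ends up with the full, replicated tensor.
    world_size = dist.get_world_size(group)
    chunks = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(chunks, local_shard, group=group)
    return torch.cat(chunks, dim=dim)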

View File

@@ -20,7 +20,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def run_with_spec(spec_init_func):

View File

@@ -20,7 +20,7 @@ def init_1d_row(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def init_1d_col(weight):
@@ -28,7 +28,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def run_with_spec(spec_init_func):

View File

@@ -22,7 +22,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)


 def init_1d_col_spec(model):
@@ -32,7 +32,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)


 def check_param_equal(model, torch_model):

View File

@@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def init_1d_col(weight, bias):
@@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)


 def run_with_spec(spec_init_func):

View File

@@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def init_1d_col_linear(weight):
@@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def init_1d_row_embedding(weight):
@@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def init_1d_col_embedding(weight):
@@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)


 def run_1d_hybrid_tp(model_name):
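The four init helpers above pick opposite shard dims for linear and embedding weights because the two modules lay their weights out differently: `torch.nn.Linear` stores `(out_features, in_features)` while `torch.nn.Embedding` stores `(num_embeddings, embedding_dim)`. Row parallelism splits the input side (dim -1 for linear, dim 0 for embedding) and column parallelism splits the output side. A plain-PyTorch check of the layouts, separate from ColossalAI:

import torch

linear = torch.nn.Linear(16, 32)
embedding = torch.nn.Embedding(1000, 64)
assert linear.weight.shape == (32, 16)       # row-parallel TP shards dim -1 (in_features)
assert embedding.weight.shape == (1000, 64)  # row-parallel TP shards dim 0 (vocabulary)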

View File

@@ -157,7 +157,7 @@ def run_check_shared_param():
     col_spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
-    model.cls.predictions.bias.set_spec(col_spec)
+    model.cls.predictions.bias.set_tensor_spec(col_spec)
     try:
         check_colo_module(model.cls.predictions.decoder, recursive=False)
     except Exception as e:

View File

@@ -36,10 +36,10 @@ def test_layernorm():

 def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
-    for k in dir(tensor.spec.dist_spec):
+    for k in dir(tensor.tensor_spec.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.spec.dist_spec, k)
-            assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
+            assert hasattr(other.tensor_spec.dist_spec, k)
+            assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)


 def check_element_wise_ops():
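`check_spec_eq` compares two dist specs attribute by attribute rather than relying on an `__eq__` implementation. The same `dir()`-based comparison, reduced to a self-contained sketch with stand-in names (nothing below is ColossalAI code):

class DistSpecStub:
    # Stand-in for a dist spec: plain data attributes and no public methods,
    # which is what makes the dir()-based walk below meaningful.
    def __init__(self, placement, num_partitions):
        self.placement = placement
        self.num_partitions = num_partitions

def attrs_equal(a, b):
    # Mirrors check_spec_eq: every non-dunder attribute of `a` must exist on
    # `b` with an equal value.
    return all(
        hasattr(b, k) and getattr(a, k) == getattr(b, k)
        for k in dir(a) if not k.startswith('__'))

assert attrs_equal(DistSpecStub('shard', 4), DistSpecStub('shard', 4))
assert not attrs_equal(DistSpecStub('shard', 4), DistSpecStub('replicate', 1))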

View File

@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
     shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
     assert t.shape == torch.Size((4 * world_size, 5))
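The assertion pins down the semantics of the spec switch: each rank builds `t` from a local `(4, 5)` clone sharded along dim 0, so converting the dist spec back to replicate reassembles the full tensor on every rank. With `world_size = 4`, for instance, four `(4, 5)` shards become one `(16, 5)` replica. A single-process simulation of that bookkeeping (hypothetical values, not ColossalAI code):

import torch

world_size = 4
shards = [torch.full((4, 5), fill_value=r) for r in range(world_size)]  # one per rank
replicated = torch.cat(shards, dim=0)  # what replication materializes on each rank
assert replicated.shape == torch.Size((4 * world_size, 5))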

View File

@@ -51,7 +51,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)


 def init_1d_col_spec(model):
@@ -61,7 +61,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)


 @parameterize('use_chunk', [False, True])
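`@parameterize` reruns the decorated test once per listed value, so the chunked and unchunked paths share one body. A simplified re-implementation of the mechanism for illustration (ColossalAI's own version lives in its testing utilities and may differ):

from functools import wraps

def parameterize(name, values):
    # Call the wrapped function once per value, injected as a keyword argument.
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for v in values:
                fn(*args, **{**kwargs, name: v})
        return wrapper
    return decorator

@parameterize('use_chunk', [False, True])
def run_case(use_chunk):
    print(f'use_chunk={use_chunk}')

run_case()  # executes twice: use_chunk=False, then use_chunk=True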