mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-08 14:46:09 +00:00
[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)
This commit is contained in:
@@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
@@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
bias.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
@@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
|
||||
x = torch.rand(2, 16).cuda()
|
||||
out = model(x)
|
||||
colo_out = torch.addmm(bias, x, weight)
|
||||
colo_out = colo_out.to_replicate()
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
|
||||
@@ -20,7 +20,7 @@ def init_1d_col(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
||||
@@ -20,7 +20,7 @@ def init_1d_row(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight):
|
||||
@@ -28,7 +28,7 @@ def init_1d_col(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
||||
@@ -22,7 +22,7 @@ def init_1d_row_spec(model):
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model):
|
||||
@@ -32,7 +32,7 @@ def init_1d_col_spec(model):
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
|
||||
@@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
@@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
bias.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
||||
@@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight):
|
||||
@@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_row_embedding(weight):
|
||||
@@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_embedding(weight):
|
||||
@@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_1d_hybrid_tp(model_name):
|
||||
|
||||
@@ -157,7 +157,7 @@ def run_check_shared_param():
|
||||
col_spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
model.cls.predictions.bias.set_spec(col_spec)
|
||||
model.cls.predictions.bias.set_tensor_spec(col_spec)
|
||||
try:
|
||||
check_colo_module(model.cls.predictions.decoder, recursive=False)
|
||||
except Exception as e:
|
||||
|
||||
@@ -36,10 +36,10 @@ def test_layernorm():
|
||||
|
||||
def check_spec_eq(tensor, other):
|
||||
assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
|
||||
for k in dir(tensor.spec.dist_spec):
|
||||
for k in dir(tensor.tensor_spec.dist_spec):
|
||||
if not k.startswith('__'):
|
||||
assert hasattr(other.spec.dist_spec, k)
|
||||
assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
|
||||
assert hasattr(other.tensor_spec.dist_spec, k)
|
||||
assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)
|
||||
|
||||
|
||||
def check_element_wise_ops():
|
||||
|
||||
@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
|
||||
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
|
||||
tensor_spec = TensorSpec(shard_spec)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
|
||||
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
|
||||
assert t.shape == torch.Size((4 * world_size, 5))
|
||||
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ def init_1d_row_spec(model):
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model):
|
||||
@@ -61,7 +61,7 @@ def init_1d_col_spec(model):
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
@parameterize('use_chunk', [False, True])
|
||||
|
||||
Reference in New Issue
Block a user