[ColoTensor] add independent process group (#1179)

This commit is contained in:
Jiarui Fang
2022-06-29 10:03:09 +08:00
committed by GitHub
parent 26ba87272d
commit 7487215b95
4 changed files with 116 additions and 45 deletions

View File

@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColoOptimizer
@@ -18,34 +18,30 @@ from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
def init_1d_row_linear(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_row_linear(weight, pg: ProcessGroup):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
def init_1d_col_linear(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_col_linear(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
def init_1d_row_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_row_embedding(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
def init_1d_col_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_col_embedding(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
@@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name):
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
p2.data.copy_(p1.data)
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
if 'bert' == model_name:
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
@@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name):
# print(name)
# num_class = type_vocab_size = 2 | (8, 2)
if 'classifier' in name and 'weight' in name:
init_1d_row_linear(p)
init_1d_row_linear(p, pg)
# num_class = vocab_size = 30524 | (30524, 8)
if 'word_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p)
init_1d_row_embedding(p, pg)
# num_class = seq_len = 512 | (512, 8)
if 'position_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p)
init_1d_row_embedding(p, pg)
# num_class = type_vocab_size = 2 | (2, 8)
if 'token_type_embeddings' in name and 'weight' in name:
init_1d_col_embedding(p)
init_1d_col_embedding(p, pg)
elif "simple_net" == model_name:
# A naive way to set spec for all weights in Linear
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
if 'embed' in name and 'weight' in name:
init_1d_col_embedding(p)
init_1d_col_embedding(p, pg)
if 'proj1' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p)
init_1d_col_linear(p, pg)
if 'proj2' in name and 'weight' in name:
init_1d_row_linear(p)
init_1d_row_linear(p, pg)
if 'classifier' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p)
init_1d_col_linear(p, pg)
model = model.cuda()
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
@@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name):
data = data.to(get_current_device())
label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
# Bcast rank0 data to all processes
if criterion:
output = model(data)
@@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str):
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
set_seed(1)
if rank == 0:
model_torch = model_builder(checkpoint=True)
@@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str):
if not isinstance(p, ColoTensor):
continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
init_1d_row_linear(p)
init_1d_row_linear(p, pg)
if 'embed' in name and 'weight' in name:
init_1d_row_embedding(p)
init_1d_row_embedding(p, pg)
model = model.cuda()
@@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port):
# The test case has to download huggingface pretrained models from the internet
# So we manually trigger the test.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def _test_pretrain_load(world_size):
def test_pretrain_load(world_size):
run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@@ -342,4 +346,4 @@ if __name__ == '__main__':
# test_model_parameters()
# test_colo_optimizer()
# test_model(4)
_test_pretrain_load(4)
test_pretrain_load(4)

View File

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec, ColoTensor
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
from colossalai.context import ParallelMode
from functools import partial
@@ -21,14 +21,6 @@ def test_tensor_indexing():
assert allclose(torch_t[:, 1], colo_t[:, 1])
@pytest.mark.skip
# FIXME(ver217): support lazy init
def test_lazy_init_tensor():
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
def test_wrapped_tensor_func():
t_ref = torch.randn(4, 5)
t = ColoTensor.from_torch_tensor(t_ref.clone())
@@ -62,10 +54,12 @@ def test_operand():
def _run_view(world_size):
t_ref = torch.randn(4, 5)
rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)))
assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
t = ColoTensor.from_torch_tensor(
t_ref,
TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
num_partitions=[world_size])))
TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
assert t.size_global()[0] == 4 * world_size
assert t.size_global(1) == 5
@@ -81,8 +75,10 @@ def _run_view(world_size):
def _run_tensor_shard_init(world_size):
t_ref = torch.randn(4, 5)
print(gpc.get_group(ParallelMode.DATA).size())
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)))
shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))