https://github.com/hpcaitech/ColossalAI.git
[ColoTensor] add independent process group (#1179)
@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
-    ComputeSpec, ColoTensor, DistSpecManager
+    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
@@ -18,34 +18,30 @@ from functools import partial
 from _utils import tensor_equal, tensor_shard_equal, set_seed


-def init_1d_row_linear(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row_linear(weight, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)


-def init_1d_col_linear(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col_linear(weight, pg):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)


-def init_1d_row_embedding(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row_embedding(weight, pg):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)


-def init_1d_col_embedding(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col_embedding(weight, pg):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)

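With this change the 1D tensor-parallel init helpers no longer query gpc for the PARALLEL_1D group; the caller passes in a ProcessGroup and the helpers read the TP group and TP world size from it. A minimal sketch of the new calling convention, reusing init_1d_row_linear from the hunk above; the toy Linear model, the empty launch config, and the spawn boilerplate are illustrative assumptions, not part of this commit:

import torch
import torch.multiprocessing as mp
from functools import partial

import colossalai
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup


def run(rank, world_size, port):
    # Illustrative single-node launch with an empty config.
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(16, 16)    # parameters are created as ColoTensors

    # One flat group over all ranks, used entirely for tensor parallelism.
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    init_1d_row_linear(model.weight, pg)   # shard the last dim across the TP group


if __name__ == '__main__':
    world_size = 4
    mp.spawn(partial(run, world_size=world_size, port=free_port()), nprocs=world_size)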
@@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name):
     for p1, p2 in zip(model.parameters(), model_torch.parameters()):
         p2.data.copy_(p1.data)

+    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
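The test builds a flat layout in which every rank belongs to a single tensor-parallel group (tp_degree equals the world size). The same constructor can describe a hybrid layout; a hedged sketch, assuming the data-parallel degree is inferred as world_size // tp_degree (consistent with the dp_* accessors used elsewhere in this diff) and that it runs inside an already-launched distributed job:

from colossalai.tensor import ProcessGroup


def build_hybrid_pg(rank: int) -> ProcessGroup:
    # Hypothetical 4-rank job split into 2-way TP x 2-way DP; must run after
    # the distributed backend has been initialized (e.g. via colossalai.launch).
    pg = ProcessGroup(rank, [0, 1, 2, 3], tp_degree=2)
    assert pg.tp_world_size() == 2    # size of this rank's tensor-parallel group
    assert pg.dp_world_size() == 2    # the remaining factor acts as the data-parallel degree
    tp_group = pg.tp_process_group()  # torch.distributed group for TP collectives
    dp_group = pg.dp_process_group()  # torch.distributed group for DP collectives
    return pg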
@@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name):
             # print(name)
             # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                init_1d_row_linear(p)
+                init_1d_row_linear(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p)
+                init_1d_row_embedding(p, pg)
             # num_class = seq_len = 512 | (512, 8)
             if 'position_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p)
+                init_1d_row_embedding(p, pg)
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
-                init_1d_col_embedding(p)
+                init_1d_col_embedding(p, pg)
     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                init_1d_col_embedding(p)
+                init_1d_col_embedding(p, pg)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p)
+                init_1d_col_linear(p, pg)
             if 'proj2' in name and 'weight' in name:
-                init_1d_row_linear(p)
+                init_1d_row_linear(p, pg)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p)
+                init_1d_col_linear(p, pg)

     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
@@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name):
         data = data.to(get_current_device())
         label = label.to(get_current_device())

-        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
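Switching the broadcast group from gpc.get_group(ParallelMode.PARALLEL_1D) to pg.tp_process_group() ties the data distribution to the same ProcessGroup that shards the weights: every rank inside one tensor-parallel group receives the batch from that group's source rank, while separate data-parallel replicas could keep distinct batches. A small sketch of the intent, assuming pg was built as in the hunks above (source rank 0 is valid here because the test's TP group spans all ranks):

import torch
import torch.distributed as dist
from colossalai.utils.cuda import get_current_device


def broadcast_batch_in_tp_group(data: torch.Tensor, label: torch.Tensor, pg):
    # Move the batch to the local device, then sync it within this rank's
    # tensor-parallel group so all TP workers compute on identical inputs.
    data = data.to(get_current_device())
    label = label.to(get_current_device())
    dist.broadcast(data, 0, group=pg.tp_process_group())
    dist.broadcast(label, 0, group=pg.tp_process_group())
    return data, label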
@@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str):
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)

+    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+
     set_seed(1)
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
@@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str):
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            init_1d_row_linear(p)
+            init_1d_row_linear(p, pg)
         if 'embed' in name and 'weight' in name:
-            init_1d_row_embedding(p)
+            init_1d_row_embedding(p, pg)

     model = model.cuda()

@@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port):

 # The test case has to download huggingface pretrained models from the internet
 # So we manually trigger the test.
+@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
-def _test_pretrain_load(world_size):
+def test_pretrain_load(world_size):
     run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

@@ -342,4 +346,4 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model(4)
-    _test_pretrain_load(4)
+    test_pretrain_load(4)

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, TensorSpec, ColoTensor
+from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
 from colossalai.context import ParallelMode
 from functools import partial

@@ -21,14 +21,6 @@ def test_tensor_indexing():
     assert allclose(torch_t[:, 1], colo_t[:, 1])


-@pytest.mark.skip
-# FIXME(ver217): support lazy init
-def test_lazy_init_tensor():
-    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
-
-
 def test_wrapped_tensor_func():
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(t_ref.clone())
@@ -62,10 +54,12 @@ def test_operand():

 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
+    rank = gpc.get_global_rank()
+    pg = ProcessGroup(rank, list(range(world_size)))
+    assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
     t = ColoTensor.from_torch_tensor(
         t_ref,
-        TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
-                                  num_partitions=[world_size])))
+        TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))

     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
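On the data-parallel side, _run_view now builds the shard spec from pg.dp_process_group() and pg.dp_world_size() instead of ParallelMode.DATA. A condensed sketch of that pattern, again assuming it runs inside an already-launched job; note that from_torch_tensor with a shard spec treats the given tensor as this rank's local shard, which is why the test expects size_global()[0] == 4 * world_size:

import torch
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup


def shard_dim0_across_dp(t_local: torch.Tensor, rank: int, world_size: int) -> ColoTensor:
    # No tp_degree given, so the whole group acts as a single data-parallel group.
    pg = ProcessGroup(rank, list(range(world_size)))
    spec = TensorSpec(distspec.shard(process_group=pg.dp_process_group(),
                                     dims=[0],
                                     num_partitions=[pg.dp_world_size()]))
    t = ColoTensor.from_torch_tensor(t_local, spec)
    # t_local is interpreted as the local shard, so the logical dim-0 size
    # is the local size multiplied by the number of partitions.
    assert t.size_global()[0] == t_local.size(0) * pg.dp_world_size()
    assert t.size_global(1) == t_local.size(1)
    return t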
@@ -81,8 +75,10 @@ def _run_view(world_size):

 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
-    print(gpc.get_group(ParallelMode.DATA).size())
-    shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
+
+    rank = gpc.get_global_rank()
+    pg = ProcessGroup(rank, list(range(world_size)))
+    shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))