mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[tensor] remove gpc in tensor tests (#1186)
This commit is contained in:
@@ -11,11 +11,9 @@ from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
|
||||
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
from _utils import tensor_shard_equal, set_seed
|
||||
|
||||
|
||||
def init_1d_row_linear(weight, pg: ProcessGroup):
|
||||
@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name):
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name):
|
||||
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
|
||||
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
|
||||
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
if 'bert' == model_name:
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str):
|
||||
# A simple net with two stacked nn.Linear
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
|
||||
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
|
||||
set_seed(1)
|
||||
@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str):
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
|
||||
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
|
Reference in New Issue
Block a user