[refactor] remove gpc dependency in colotensor's _ops (#1189)

Author: Jiarui Fang
Date: 2022-07-04 18:54:37 +08:00
Committed by: GitHub
Parent commit: abf6a262dc
This commit: 060b917daf
33 changed files with 499 additions and 357 deletions
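At its core, the refactor replaces rank lookups through the global context singleton (gpc) with an explicit ProcessGroup object from colossalai.tensor. A minimal before/after sketch of the pattern applied throughout the diffs below, assuming the default ProcessGroup() spans the whole world as a single data-parallel group (as the updated tests imply):

# Before: the parallel rank is read from the global context singleton.
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode

dp_rank = gpc.get_local_rank(ParallelMode.DATA)

# After: the rank is read from an explicit, self-contained ProcessGroup.
from colossalai.tensor import ProcessGroup

pg = ProcessGroup()           # assumption: defaults to one world-wide DP group
dp_rank = pg.dp_local_rank()  # data-parallel local rank, no gpc needed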


@@ -1,12 +1,10 @@
 import pytest
 import colossalai
 import torch
-from colossalai.context.parallel_mode import ParallelMode
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from functools import partial
 from tests.test_tensor._utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from colossalai.tensor import ProcessGroup
 
 def init_zero(model_builder, placement_policy):
@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):
     model.train()
     model_copy.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    pg = ProcessGroup()
+    set_seed(pg.dp_local_rank())
     data_iter = iter(train_dataloader)
     data, label = map(lambda x: x.cuda(), next(data_iter))
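Seeding by the data-parallel local rank keeps each replica's RNG stream deterministic without consulting gpc. The set_seed helper comes from tests/test_tensor/_utils; its body is not part of this commit, but a hypothetical sketch of what such a helper typically does:

import random
import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Hypothetical sketch of the test helper: pin every RNG the test touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)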


@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
 from colossalai.testing import parameterize
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
+from colossalai.tensor import ProcessGroup
 
 def init_zero(model, use_chunk, use_zero, placement_policy):
@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    return ZeroDDP(model, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(model, gemini_manager, pg)
 
 def run_step(model, optim, criterion, data, label):
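With the extra argument, ZeroDDP binds its distributed storage to the group it is handed rather than resolving one through gpc. A hypothetical call site for the updated helper (MyModel and the keyword values are illustrative, not from the commit):

# Hypothetical usage: init_zero now builds the ProcessGroup internally
# and passes it to ZeroDDP, so callers stay unchanged.
model = init_zero(MyModel().cuda(), use_chunk=True, use_zero=True,
                  placement_policy='cuda')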