Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[refactor] remove gpc dependency in colotensor's _ops (#1189)
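The hunks below drop the global-context (gpc) based rank lookup from these tests and query a ProcessGroup object instead. As a rough sketch of the pattern, based only on the calls visible in this diff (not a full description of the ProcessGroup API, and assuming a distributed colossalai launch has already happened):

    # Before: the data-parallel rank came from the global context
    #   from colossalai.core import global_context as gpc
    #   from colossalai.context.parallel_mode import ParallelMode
    #   rank = gpc.get_local_rank(ParallelMode.DATA)

    # After: a ProcessGroup instance is created and queried directly
    from colossalai.tensor import ProcessGroup

    pg = ProcessGroup()
    rank = pg.dp_local_rank()    # data-parallel local rank, used here to seed the test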
@@ -1,12 +1,10 @@
 import pytest
 import colossalai
 import torch
-from colossalai.context.parallel_mode import ParallelMode
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from functools import partial
 from tests.test_tensor._utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from colossalai.tensor import ProcessGroup


 def init_zero(model_builder, placement_policy):
@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):

     model.train()
     model_copy.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    pg = ProcessGroup()
+    set_seed(pg.dp_local_rank())
     data_iter = iter(train_dataloader)

     data, label = map(lambda x: x.cuda(), next(data_iter))

@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
 from colossalai.testing import parameterize
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
+from colossalai.tensor import ProcessGroup


 def init_zero(model, use_chunk, use_zero, placement_policy):
@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    return ZeroDDP(model, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(model, gemini_manager, pg)


 def run_step(model, optim, criterion, data, label):
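The second file applies the same idea to the ZeroDDP wrapper: the process group is now passed in explicitly rather than resolved from gpc internally. A minimal sketch of the post-refactor construction, assuming the test's existing chunk_manager and placement_policy setup and using only names that appear in the hunks above:

    # Sketch only: `model`, `chunk_manager` and `placement_policy` are assumed to be
    # prepared as in the test's init_zero() above.
    from colossalai.tensor import ProcessGroup

    pg = ProcessGroup()
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    zero_model = ZeroDDP(model, gemini_manager, pg)    # ProcessGroup is now an explicit argument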