https://github.com/hpcaitech/ColossalAI.git
[zero] polish ZeroInitContext (#540)
@@ -9,11 +9,11 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -32,8 +32,7 @@ def run_dist(rank, world_size, port, parallel_config):
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-        with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=gpc.config.zero.model_config.shard_strategy,
                              shard_param=True):
             colo_model = model_builder(checkpoint=True)
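
The second hunk is the substance of the change for this test: ZeroInitContext no longer accepts a convert_fp16 argument, so the call site passes only target_device, shard_strategy, and shard_param. The following is a minimal sketch of that post-#540 usage outside the test harness; it is not taken from this commit. The TensorShardStrategy import path, the toy nn.Sequential model, and an already-initialized distributed environment (e.g. via colossalai.launch) are assumptions made for illustration.

import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path

# Assumes torch.distributed / colossalai has already been initialized,
# since sharding the parameters requires a process group.
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    # Parameters created inside the context are sharded as they are
    # materialized, so the full model never resides on one device.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

In the test itself the strategy still comes from gpc.config.zero.model_config.shard_strategy; how fp16 conversion is handled after the removal of convert_fp16 is decided elsewhere in the ZeRO stack and is not shown in this hunk.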