[zero] update zero context init with the updated test utils (#327)

Author: Jiarui Fang
Date: 2022-03-08 14:45:01 +08:00
Committed by: GitHub
Parent: 6afc4f9e11
Commit: cec05b25c9
10 changed files with 96 additions and 49 deletions


@@ -9,22 +9,27 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from colossalai.zero.init_ctx import ZeroInitContext
-from common import CONFIG, Net
+from common import CONFIG
 from colossalai.utils import free_port
+from tests.components_to_test.registry import non_distributed_component_funcs


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
-        # Note Net(checkpoint=True).cuda() moving to cuda is useless
-        model = Net(checkpoint=True)
-
-    for param in model.parameters():
-        assert hasattr(param, 'ca_attr')
-        assert param.ca_attr.data.dtype == torch.half
-        assert param.ca_attr._data_sharded_tensor.is_sharded
-        assert param.ca_attr.data.device.type == 'cuda'
+    for get_components_func in non_distributed_component_funcs:
+        model_builder, _, _, _, _ = get_components_func()
+        with ZeroInitContext(convert_fp16=True,
+                             convert_cuda=True,
+                             shard_strategy=TensorShardStrategy(),
+                             shard_param=True):
+            model = model_builder(checkpoint=True)
+
+        for param in model.parameters():
+            assert hasattr(param, 'ca_attr')
+            assert param.ca_attr.data.dtype == torch.half
+            assert param.ca_attr._data_sharded_tensor.is_sharded
+            assert param.ca_attr.data.device.type == 'cuda'


 @pytest.mark.dist
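
For context, the change above replaces a hardcoded Net with iteration over a registry of reusable test components. A minimal sketch of how such a registry could look (the Registry class, decorator usage, and example entry below are illustrative assumptions, not the actual tests/components_to_test implementation):

# Sketch of a test-component registry (assumed design). Each registered
# function returns a tuple of builders/helpers that tests iterate over.
class Registry:
    def __init__(self):
        self._funcs = []

    def register(self, func):
        # Used as a decorator: records the component function and returns it.
        self._funcs.append(func)
        return func

    def __iter__(self):
        # Enables `for get_components_func in non_distributed_component_funcs:`
        # exactly as in the test above.
        return iter(self._funcs)

non_distributed_component_funcs = Registry()

@non_distributed_component_funcs.register
def simple_net_components():
    # Hypothetical entry returning a 5-tuple, matching the unpacking
    # `model_builder, _, _, _, _ = get_components_func()` in the test.
    import torch.nn as nn

    def model_builder(checkpoint=False):
        # `checkpoint` is accepted for API compatibility with the test call;
        # this toy model ignores it.
        return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

    return model_builder, None, None, None, None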


@@ -46,6 +46,8 @@ def _run_shard_param_v2(rank, world_size, port):
     sparam = ShardedParamV2(param=param, process_group=None)
     allclose(sparam.data, param_ref.data)
+    sparam.remove_torch_payload()
+    assert (param.data.numel() == 1)
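
The two added lines check that remove_torch_payload() releases the parameter's original storage once its payload lives in the sharded tensor, leaving only a one-element placeholder. A minimal sketch of that idea (assumed semantics, not the actual ShardedParamV2 method):

import torch

def remove_torch_payload(param: torch.nn.Parameter) -> None:
    # Replace the full payload with a tiny placeholder on the same
    # device/dtype, so `param.data.numel() == 1` as the test asserts.
    param.data = torch.empty(1, dtype=param.dtype, device=param.device)

p = torch.nn.Parameter(torch.randn(4, 4))
remove_torch_payload(p)
assert p.data.numel() == 1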