[zero] fix init device bug in zero init context unittest (#516)

This commit is contained in:
Jiarui Fang
2022-03-25 12:24:18 +08:00
committed by GitHub
parent a513164379
commit 0bebda6ea5
8 changed files with 55 additions and 37 deletions

View File

@@ -19,17 +19,24 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device, shard_strategy_class):
def run_model_test(init_device_type, shard_strategy_class):
for get_components_func in non_distributed_component_funcs:
model_builder, _, _, _, _ = get_components_func()
model_numel_tensor = torch.zeros(1, dtype=torch.int)
if init_device_type == 'cuda':
init_device = torch.device(f"cuda:{get_current_device()}")
elif init_device_type == 'cpu':
init_device = torch.device("cpu")
else:
continue
with ZeroInitContext(convert_fp16=True,
target_device=init_device,
shard_strategy=shard_strategy_class(),
shard_param=True,
model_numel_tensor=model_numel_tensor):
model_numel_tensor=model_numel_tensor,
rm_torch_payload_on_the_fly=False):
model = model_builder(checkpoint=True)
for param in model.parameters():
@@ -38,11 +45,9 @@ def run_model_test(init_device, shard_strategy_class):
assert param.col_attr.sharded_data_tensor.is_sharded
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
print(f'numel {model_numel_tensor}')
if init_device.type == 'cuda':
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
if init_device.type == 'cuda':
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
GLOBAL_MODEL_DATA_TRACER.clear()
def run_dist(rank, world_size, port):