Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] fix init device bug in zero init context unittest (#516)
@@ -19,17 +19,24 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
 
 
-@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+@parameterize("init_device_type", ['cpu', 'cuda'])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(init_device, shard_strategy_class):
+def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
+        if init_device_type == 'cuda':
+            init_device = torch.device(f"cuda:{get_current_device()}")
+        elif init_device_type == 'cpu':
+            init_device = torch.device("cpu")
+        else:
+            continue
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
-                             model_numel_tensor=model_numel_tensor):
+                             model_numel_tensor=model_numel_tensor,
+                             rm_torch_payload_on_the_fly=False):
             model = model_builder(checkpoint=True)
 
         for param in model.parameters():
@@ -38,11 +45,9 @@ def run_model_test(init_device, shard_strategy_class):
             assert param.col_attr.sharded_data_tensor.is_sharded
             assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
-
-        print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-        print(f'numel {model_numel_tensor}')
-        if init_device.type == 'cuda':
-            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        if init_device.type == 'cuda':
+            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+            GLOBAL_MODEL_DATA_TRACER.clear()
 
 
 def run_dist(rank, world_size, port):
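The core of the fix is visible in the first hunk: the CUDA device is no longer baked into the @parameterize decorator (where get_current_device() would run at import time, before the process has been bound to its GPU), but is resolved inside the test body from a plain 'cpu'/'cuda' string. Below is a minimal, self-contained sketch of that pattern, not the repository's test: torch.cuda.current_device() stands in for colossalai's get_current_device(), and resolve_init_device is an illustrative helper, not an existing API.

import torch


def resolve_init_device(init_device_type: str) -> torch.device:
    """Build the target device lazily, at call time.

    Resolving the CUDA index here, rather than in a decorator argument that is
    evaluated at import time, means the test picks up whatever device the
    launcher assigned to this rank after CUDA/distributed initialization ran.
    """
    if init_device_type == 'cuda':
        # Plays the role of get_current_device() in the diff above.
        return torch.device(f'cuda:{torch.cuda.current_device()}')
    elif init_device_type == 'cpu':
        return torch.device('cpu')
    raise ValueError(f'unknown init_device_type: {init_device_type}')


def run_model_test(init_device_type: str) -> None:
    init_device = resolve_init_device(init_device_type)
    # In the real test this device is handed to ZeroInitContext(target_device=...);
    # here we just materialize a parameter on it to show the placement check.
    param = torch.nn.Parameter(torch.empty(4, 4, device=init_device))
    assert param.device.type == init_device.type, \
        f'{param.device.type} vs. {init_device.type}'


if __name__ == '__main__':
    for device_type in ['cpu'] + (['cuda'] if torch.cuda.is_available() else []):
        run_model_test(device_type)
        print(f'ok: {device_type}')

Keeping the parameterization on plain strings also keeps the decorator cheap and safe to evaluate on machines without CUDA; the concrete device object only exists once the test body runs.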
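The second hunk drops the debug prints and clears GLOBAL_MODEL_DATA_TRACER after the CUDA assertion. The sketch below illustrates why resetting a module-level usage counter between parameterized runs matters; _UsageTracker is a stand-in invented for illustration, and only the cuda_usage attribute and clear() method mirror what the diff actually uses.

# Illustrative stand-in for a global usage tracker.
class _UsageTracker:
    def __init__(self) -> None:
        self.cuda_usage = 0

    def add(self, nbytes: int) -> None:
        self.cuda_usage += nbytes

    def clear(self) -> None:
        self.cuda_usage = 0


GLOBAL_TRACKER = _UsageTracker()


def run_case(init_device_type: str) -> None:
    if init_device_type == 'cuda':
        GLOBAL_TRACKER.add(1024)  # pretend the init context placed shards on GPU
        assert GLOBAL_TRACKER.cuda_usage > 0
        # Without this reset, usage recorded here would leak into the next
        # parameterized run and could mask a case that allocates nothing.
        GLOBAL_TRACKER.clear()


if __name__ == '__main__':
    for case in ['cuda', 'cpu', 'cuda']:
        run_case(case)
    assert GLOBAL_TRACKER.cuda_usage == 0
    print('tracker is clean between runs')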