mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[polish] rename col_attr -> colo_attr (#558)
This commit is contained in:
@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
|
||||
model = MoeModel()
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
assert hasattr(param, 'colo_attr')
|
||||
|
||||
# the weights in the gate should be fp32
|
||||
if 'gate' in name:
|
||||
assert param.col_attr.sharded_data_tensor.dtype == torch.float32
|
||||
assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
|
||||
else:
|
||||
assert param.col_attr.sharded_data_tensor.dtype == torch.half
|
||||
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
|
||||
|
||||
# the parameters in moe experts and its gate should not be sharded
|
||||
if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
|
||||
assert not param.col_attr.sharded_data_tensor.is_sharded
|
||||
assert not param.colo_attr.sharded_data_tensor.is_sharded
|
||||
else:
|
||||
assert param.col_attr.sharded_data_tensor.is_sharded
|
||||
assert param.colo_attr.sharded_data_tensor.is_sharded
|
||||
|
||||
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
|
@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
# zero_grad = zero_p.grad.clone().to(p.device)
|
||||
zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
|
||||
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
|
||||
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
|
||||
if reuse_fp16_shard:
|
||||
zero_p = zero_p.data.to(p.device).float()
|
||||
else:
|
||||
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
|
||||
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
|
||||
chunks = torch.flatten(p).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
|
@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
for param in model.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
assert param.col_attr.sharded_data_tensor.dtype == torch.half
|
||||
assert param.col_attr.sharded_data_tensor.is_sharded
|
||||
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
assert hasattr(param, 'colo_attr')
|
||||
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
|
||||
assert param.colo_attr.sharded_data_tensor.is_sharded
|
||||
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
|
||||
cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
|
||||
model_data_cuda_mem_MB = cuda_mem_use / 1e6
|
||||
|
Reference in New Issue
Block a user