mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-19 00:55:09 +00:00)
[zero] adapt zero hooks for unsharded module (#699)
@@ -126,16 +126,15 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
     rank = dist.get_rank()
     for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         if zero_p.colo_attr.param_is_sharded:
-            if reuse_fp16_shard:
-                zero_p = zero_p.data.to(p.device).float()
-            else:
-                zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
             chunks = torch.flatten(p).chunk(dist.get_world_size())
             if rank >= len(chunks):
                 continue
             p = chunks[rank].float()
             if zero_p.size(0) > p.size(0):
                 zero_p = zero_p[:p.size(0)]
         else:
             zero_p = zero_p.colo_attr.sharded_data_tensor.payload

         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
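The branch collapses to one line, presumably because after this commit the fp16 shard is reachable through zero_p.colo_attr.sharded_data_tensor.payload whether or not the fp16 shard is reused. A minimal standalone sketch of the per-rank comparison the loop performs (check_shard and its tolerance are illustrative, not part of ColossalAI's test utilities):

# Hypothetical standalone sketch; check_shard and atol are illustrative,
# not ColossalAI API.
import torch

def check_shard(full_param: torch.Tensor, shard: torch.Tensor,
                rank: int, world_size: int, atol: float = 1e-3) -> bool:
    # Split the flattened full parameter exactly as the loop above does.
    chunks = torch.flatten(full_param).chunk(world_size)
    if rank >= len(chunks):
        return True    # this rank holds no chunk of the parameter
    ref = chunks[rank].float()
    # Shards may be zero-padded at the tail, so truncate before comparing.
    return torch.allclose(ref, shard.float()[:ref.size(0)], atol=atol)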
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd


 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
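Dropping TensorShardStrategy halves the strategy dimension of this test's matrix; each stacked @parameterize runs the decorated test once per listed value. A hedged reimplementation sketch of that stacking behavior (assumed semantics only, not colossalai.testing's actual source):

# Illustrative parameterize-style decorator; assumed semantics only,
# not colossalai.testing's real implementation.
from functools import wraps

def parameterize(name, values):
    def decorator(fn):
        @wraps(fn)
        def wrapper(**kwargs):
            for value in values:               # run once per listed value
                fn(**{**kwargs, name: value})
        return wrapper
    return decorator

@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", ["BucketTensorShardStrategy"])  # stand-in string
def run_model_test(enable_autocast, shard_strategy_class):
    print(enable_autocast, shard_strategy_class)

run_model_test()    # the body executes once per value combination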
@@ -58,15 +58,15 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

     sparam.remove_torch_payload()
-    assert (param.data.numel() == 1)
+    assert (param.data.numel() == 0)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    # 4 is size of dummy tensor of param.data
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+    assert cpu_mem_use == 2 * 3 * 4 * 2

     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+    assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0

     # append a grad to torch param
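These asserts track a behavior change in remove_torch_payload: previously param.data was left pointing at a one-element dummy tensor (one fp32 element, hence the old "+ 4" bytes); now it apparently becomes a zero-element tensor, so CPU usage is exactly the two 2x3 fp32 tensors (2 * 3 * 4 bytes each). A sketch reproducing just that arithmetic (tensor_bytes is a made-up helper, not ColossalAI's get_memory_usage):

# Reproduces only the byte arithmetic behind the asserts above;
# tensor_bytes is a made-up helper, not ColossalAI's accounting.
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

param = torch.nn.Parameter(torch.randn(2, 3))
assert tensor_bytes(param.data) == 2 * 3 * 4    # one 2x3 fp32 tensor

# Old behavior (removed assert): a 1-element dummy remained after payload removal.
assert tensor_bytes(torch.empty(1)) == 4        # the old "+ 4"

# New behavior: an empty tensor replaces the payload, adding no bytes.
param.data = torch.empty(0, dtype=param.dtype)
assert param.data.numel() == 0 and tensor_bytes(param.data) == 0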
@@ -56,4 +56,4 @@ def test_zero_state_dict(world_size):


 if __name__ == '__main__':
-    test_zero_state_dict(2, TensorShardStrategy)
+    test_zero_state_dict(2)