[zero] adapt zero hooks for unsharded module (#699)

Author: HELSON, 2022-04-08 20:23:26 +08:00, committed by GitHub
parent 896ade15d6
commit ee112fe1da
12 changed files with 71 additions and 59 deletions

View File

@@ -126,16 +126,15 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
     rank = dist.get_rank()
     for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         if zero_p.colo_attr.param_is_sharded:
-            if reuse_fp16_shard:
-                zero_p = zero_p.data.to(p.device).float()
-            else:
-                zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
             chunks = torch.flatten(p).chunk(dist.get_world_size())
             if rank >= len(chunks):
                 continue
             p = chunks[rank].float()
             if zero_p.size(0) > p.size(0):
                 zero_p = zero_p[:p.size(0)]
+        else:
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
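The sharded branch now always reads the shard from colo_attr.sharded_data_tensor.payload (the reuse_fp16_shard special case is gone), while the new else branch compares an unsharded parameter against its raw payload directly. The sharded comparison flattens the full parameter, takes this rank's chunk, and truncates the shard's zero padding before comparing. A minimal sketch of that chunk-and-truncate logic with plain tensors, outside the ShardedParamV2 machinery (the helper name, shard layout, and padding here are illustrative assumptions):

import torch

def chunk_matches_shard(p: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int) -> bool:
    """Compare one rank's chunk of the full parameter against its shard."""
    chunks = torch.flatten(p).chunk(world_size)
    if rank >= len(chunks):              # tiny params may not give every rank a chunk
        return True
    chunk = chunks[rank].float()
    if shard.size(0) > chunk.size(0):    # the last shard may carry zero padding
        shard = shard[:chunk.size(0)]
    return torch.allclose(chunk, shard.float())

# toy check: 10 elements over 4 ranks gives chunk sizes (3, 3, 3, 1),
# so the last shard is padded to 3 and must be truncated before comparing
p = torch.arange(10, dtype=torch.float32)
shards = [torch.cat([c, torch.zeros(3 - c.numel())]) for c in torch.flatten(p).chunk(4)]
assert all(chunk_matches_shard(p, shards[r], r, 4) for r in range(4))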

View File

@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
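For context, stacked @parameterize decorators run the test body once per value, a cross-product when stacked, so dropping TensorShardStrategy halves the runs of this test. A rough, hypothetical reimplementation of that decorator pattern (not ColossalAI's actual parameterize), just to show the mechanics:

import functools

def parameterize(arg_name, values):
    """Run the wrapped function once per value, passed as a keyword argument."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for v in values:
                fn(*args, **{**kwargs, arg_name: v})
        return wrapper
    return decorator

@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [list, tuple])   # stand-ins for the strategy classes
def run_model_test(enable_autocast, shard_strategy_class):
    print(enable_autocast, shard_strategy_class.__name__)

run_model_test()   # two runs: (True, list) and (True, tuple)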

View File

@@ -58,15 +58,15 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
     sparam.remove_torch_payload()
-    assert (param.data.numel() == 1)
+    assert (param.data.numel() == 0)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     # 4 is size of dummy tensor of param.data
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+    assert cpu_mem_use == 2 * 3 * 4 * 2
     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
+    assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
     # append a grad to torch param
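The byte arithmetic behind these asserts: a (2, 3) fp32 tensor is 2 * 3 * 4 = 24 bytes, and the trailing * 2 covers two such tensors tracked on CPU. The dropped "+ 4" was the 1-element fp32 dummy that remove_torch_payload previously left in param.data (per the inline comment); after this commit it leaves a 0-element tensor (numel() == 0), which costs nothing. A quick standalone check of that accounting, using plain torch with no ShardedParamV2:

import torch

full = torch.randn(2, 3)                                # the test's (2, 3) fp32 tensor
bytes_per_tensor = full.numel() * full.element_size()   # 2 * 3 * 4 = 24
assert bytes_per_tensor * 2 == 2 * 3 * 4 * 2            # two tensors tracked on CPU

old_dummy = torch.empty(1)   # what remove_torch_payload left behind before this commit
new_dummy = torch.empty(0)   # what it leaves now
assert old_dummy.numel() * old_dummy.element_size() == 4   # the dropped "+ 4"
assert new_dummy.numel() * new_dummy.element_size() == 0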

View File

@@ -56,4 +56,4 @@ def test_zero_state_dict(world_size):
 if __name__ == '__main__':
-    test_zero_state_dict(2, TensorShardStrategy)
+    test_zero_state_dict(2)
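This updates the stale __main__ call to match the one-argument signature shown in the hunk header; the shard strategy is no longer chosen at the call site, presumably being parameterized inside the test instead. A hypothetical sketch of that pattern, reusing the minimal parameterize from the sketch above, with empty stand-ins for the real strategy classes:

class TensorShardStrategy: ...
class BucketTensorShardStrategy: ...

@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def test_zero_state_dict(world_size, shard_strategy_class):
    print(world_size, shard_strategy_class.__name__)

test_zero_state_dict(2)   # strategy comes from the decorator, not the call site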