[zero] hijack p.grad in sharded model (#554)

* hijack p.grad in sharded model

* polish comments

* polish comments
Author: ver217
Date: 2022-03-30 18:14:50 +08:00
Committed by: GitHub
Parent: f552b11294
Commit: 014bac0c49
6 changed files with 45 additions and 55 deletions
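
Context for the diffs below (an illustrative standalone sketch, not ColossalAI's implementation): "hijacking p.grad" here means moving each parameter's gradient out of the autograd-managed p.grad slot and into a wrapper stored on the parameter, so consumers read it via p.col_attr.saved_grad.payload. Only the names col_attr, saved_grad, payload, and StatefulTensor come from the diffs; the ColAttr class and hijack_grads helper below are toy stand-ins.

# Minimal sketch of "hijacking p.grad" (assumptions noted above).
import torch
import torch.nn as nn


class StatefulTensor:
    """Toy stand-in for a tensor wrapper exposing a .payload tensor."""

    def __init__(self, tensor):
        self.payload = tensor


class ColAttr:
    """Toy per-parameter attribute bag; only saved_grad is used here."""

    def __init__(self):
        self.saved_grad = StatefulTensor(None)


def hijack_grads(model: nn.Module) -> None:
    """Move every p.grad into p.col_attr.saved_grad and release p.grad."""
    for p in model.parameters():
        if p.grad is None:
            continue
        if not hasattr(p, 'col_attr'):
            p.col_attr = ColAttr()
        p.col_attr.saved_grad = StatefulTensor(p.grad.detach().clone())
        p.grad = None  # the "hijack": autograd's grad slot no longer holds the data


if __name__ == '__main__':
    model = nn.Linear(4, 2)
    model(torch.randn(3, 4)).sum().backward()
    hijack_grads(model)
    for p in model.parameters():
        assert p.grad is None
        print(p.col_attr.saved_grad.payload.shape)

This is why the test in the first hunk now reads zero_p.col_attr.saved_grad.payload instead of zero_p.grad.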

@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
+        # zero_grad = zero_p.grad.clone().to(p.device)
+        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
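
Side note on the check above: the shard a rank holds may be right-padded so every rank stores an equal-sized piece, which is why the reference gradient is flattened and chunked by world size before comparison. A small single-process sketch of that chunk/padding arithmetic (the shapes, world size, and rank here are made up for illustration):

import torch

world_size, rank = 4, 3
full_grad = torch.arange(10, dtype=torch.float32)    # 10 elements, not divisible by 4
chunks = torch.flatten(full_grad).chunk(world_size)  # chunk sizes: 3, 3, 3, 1
padded_shard = torch.zeros(3)                        # what the last rank would store, padded
padded_shard[:chunks[rank].numel()] = chunks[rank]
# compare only the meaningful prefix, ignoring the padding
assert torch.allclose(padded_shard[:chunks[rank].numel()], chunks[rank])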

@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
     # Test get memory usage
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = StatefulTensor(None)
sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0
     # reuse torch grad for sparam
-    sparam.fp32_grad = StatefulTensor(param.grad)
+    sparam.saved_grad = StatefulTensor(param.grad)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
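
One way to read the memory assertions above (an interpretation, not taken from the library): the sharded data tensor and saved_grad each hold a 2x3 fp32 tensor on the CPU, the 2 * 3 * 2 CUDA figure matches a 2x3 fp16 tensor, and the extra "+ 4" after remove_torch_payload() would correspond to a single-element fp32 placeholder left in place of the removed payload. The arithmetic:

# Hedged byte accounting for the assertions above (assumed interpretation).
numel = 2 * 3          # torch.randn(2, 3)
fp32, fp16 = 4, 2      # bytes per element
assert numel * fp32 * 2 == 2 * 3 * 4 * 2                  # data + saved_grad on CPU: 48 bytes
assert numel * fp16 == 2 * 3 * 2                          # fp16 tensor on CUDA: 12 bytes
assert numel * fp32 * 2 + 1 * fp32 == 2 * 3 * 4 * 2 + 4   # plus a 4-byte placeholder (assumption)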