Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] hijack p.grad in sharded model (#554)
* hijack p.grad in sharded model
* polish comments
* polish comments
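As the diff below shows, the test's separate `fp32_grad` field is replaced by a single `saved_grad` StatefulTensor, which the sharded model then surfaces through the parameter's `.grad` attribute (the "hijack" in the title). A minimal sketch of the idea, assuming hypothetical `ShardedParam` and `attach_grad` names that are not the actual ColossalAI API:

```python
import torch

# Hypothetical sketch of "hijacking p.grad": the sharded runtime owns the
# gradient buffer (`saved_grad`) and points the parameter's `.grad` at it,
# so optimizers read and write the runtime-managed tensor transparently.
# `ShardedParam` and `attach_grad` are illustrative names, not the real API.
class ShardedParam:
    def __init__(self, param: torch.nn.Parameter):
        self.param = param
        # Buffer owned by the wrapper; its placement (CPU/CUDA) is under
        # the sharded runtime's control, not autograd's.
        self.saved_grad = torch.zeros_like(param, device="cpu")

    def attach_grad(self) -> None:
        # The hijack: p.grad now aliases the wrapper-owned buffer.
        self.param.grad = self.saved_grad


p = torch.nn.Parameter(torch.randn(2, 3))
sp = ShardedParam(p)
sp.attach_grad()
assert p.grad is sp.saved_grad  # optimizer updates hit the owned buffer
```

Because `p.grad` aliases a buffer the runtime already tracks, an accounting utility like `get_memory_usage()` can attribute the gradient's bytes to the right device without double-counting, which is what the "reuse torch grad" hunk below exercises.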
@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
     # Test get memory usage
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
 
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 2 * 3 * 2
 
     sparam.fp16_grad = StatefulTensor(None)
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0
 
     # reuse torch grad for sparam
-    sparam.fp32_grad = StatefulTensor(param.grad)
+    sparam.saved_grad = StatefulTensor(param.grad)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
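The byte counts in the assertions follow directly from tensor shape and dtype: a 2x3 fp32 tensor is 2 * 3 * 4 = 24 bytes, so two of them on the CPU (the sharded payload plus `saved_grad`) give 2 * 3 * 4 * 2 = 48; a 2x3 fp16 tensor accounts for the 2 * 3 * 2 = 12 CUDA bytes; and the extra + 4 after `remove_torch_payload()` presumably covers a single-element fp32 placeholder left behind. A quick sanity check of the arithmetic:

```python
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    # Total storage for a dense tensor: element count times element size.
    return t.numel() * t.element_size()

assert tensor_bytes(torch.empty(2, 3, dtype=torch.float32)) == 24  # 2 * 3 * 4
assert tensor_bytes(torch.empty(2, 3, dtype=torch.float16)) == 12  # 2 * 3 * 2
assert tensor_bytes(torch.empty(1, dtype=torch.float32)) == 4      # placeholder
```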
||||