mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[zero] trace states of fp16/32 grad and fp32 param (#571)
This commit is contained in:
@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
|
||||
# 4 is size of dummy tensor of param.data
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
|
||||
sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
assert cuda_mem_use == 2 * 3 * 2
|
||||
|
||||
sparam.fp16_grad = StatefulTensor(None)
|
||||
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
|
||||
sparam.remove_torch_payload()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
|
||||
Reference in New Issue
Block a user