mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] hijack p.grad in sharded model (#554)
* hijack p.grad in sharded model
* polish comments
* polish comments
@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
+        # zero_grad = zero_p.grad.clone().to(p.device)
+        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
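The test change follows from the commit title: because the sharded model hijacks p.grad, the gradient no longer lives on the parameter itself, and the test must read it from zero_p.col_attr.saved_grad.payload instead. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. It is an illustration only, not ColossalAI's implementation: GradPayload and hijack_grads are hypothetical names invented here, and only the saved_grad.payload attribute chain mirrors what the test above reads.

import torch
import torch.nn as nn


class GradPayload:
    """Wraps a captured gradient; `payload` mimics the attribute chain
    the test reads via `col_attr.saved_grad.payload` (hypothetical here)."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor


def hijack_grads(model: nn.Module) -> None:
    """Register per-parameter hooks that move each gradient onto a side
    attribute (`p.saved_grad`) and keep `p.grad` effectively empty."""
    for p in model.parameters():
        if not p.requires_grad:
            continue

        def hook(grad: torch.Tensor, p: nn.Parameter = p):
            # Capture the real gradient on a side attribute.
            p.saved_grad = GradPayload(grad.detach().clone())
            # Accumulate zeros into p.grad; a real sharded implementation
            # would instead reduce-scatter the gradient across ranks and
            # leave p.grad unset after backward.
            return torch.zeros_like(grad)

        p.register_hook(hook)


model = nn.Linear(4, 2)
hijack_grads(model)
model(torch.randn(3, 4)).sum().backward()
for p in model.parameters():
    assert int(torch.count_nonzero(p.grad)) == 0  # grad was hijacked
    print(p.saved_grad.payload.shape)             # real grad lives here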