[polish] rename col_attr -> colo_attr (#558)

This commit is contained in:
Jiarui Fang
2022-03-31 12:25:45 +08:00
committed by GitHub
parent 2c45efc398
commit 7675366fce
9 changed files with 91 additions and 91 deletions

View File

@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue