[hotfix] adapt ProcessGroup and Optimizer to ColoTensor (#1388)

This commit is contained in:
HELSON
2022-07-29 19:33:24 +08:00
committed by GitHub
parent ad678921db
commit c7221cb2d4
7 changed files with 20 additions and 16 deletions

View File

@@ -116,9 +116,9 @@ class HybridAdam(NVMeOptimizer):
state['step'] = 0
# gradient momentums
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
self._post_state_init(p)
state['step'] += 1