[zero] sharded optim support hybrid cpu adam (#486)

* sharded optim support hybrid cpu adam

* update unit test

* polish docstring
ver217
2022-03-22 14:56:59 +08:00
committed by GitHub
parent b334822163
commit 62b0a8d644
5 changed files with 64 additions and 48 deletions


@@ -1,9 +1,13 @@
import math
import torch
class CPUAdam(torch.optim.Optimizer):
optimizer_id = 0
# Number of fp32 shards per parameter:
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4
def __init__(self,
model_params,
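
The new class attribute `num_fp32_shards_per_param = 4` records that CPUAdam keeps four fp32 tensors per parameter (weight, grad, momentum, variance). A minimal sketch of how a memory planner might use that constant when deciding CPU placement; the helper name and usage below are illustrative, not part of this diff:

```python
import torch
import torch.nn as nn

NUM_FP32_SHARDS_PER_PARAM = 4  # weight, grad, momentum, variance

def estimate_cpu_adam_bytes(model: nn.Module) -> int:
    """Rough fp32 optimizer-state footprint: 4 bytes per element per shard."""
    numel = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return numel * 4 * NUM_FP32_SHARDS_PER_PARAM

model = nn.Linear(1024, 1024)
print(f"~{estimate_cpu_adam_bytes(model) / 2**20:.1f} MiB of fp32 optimizer state")
```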
@@ -106,10 +110,6 @@ class CPUAdam(torch.optim.Optimizer):
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
elif target_device.type == 'cuda':
# FIXME() prepare grad on cuda
if p.grad.device.type == 'cpu':
p.grad = p.grad.to(target_device)
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
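
After this change the CUDA branch no longer copies a CPU-resident grad over; the step assumes the hybrid placement already put the grad and the `exp_avg`/`exp_avg_sq` states on the parameter's device. A minimal sketch of that per-parameter dispatch, with a plain PyTorch Adam update standing in for the fused CPU/CUDA kernels (names such as `hybrid_step` and `adam_update` are illustrative assumptions, not this file's API):

```python
import torch

def adam_update(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                beta1=0.9, beta2=0.999, eps=1e-8):
    """Plain fp32 Adam update, standing in for the fused CPU/CUDA kernels."""
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_c1 = 1 - beta1 ** step
    bias_c2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_c2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_c1)

def hybrid_step(p: torch.Tensor, state: dict, step: int) -> None:
    target_device = p.device
    # Hybrid placement: the grad is expected to already sit on the param's device.
    assert p.grad is not None and p.grad.device == target_device

    if target_device.type == 'cpu':
        adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], step)
    elif target_device.type == 'cuda':
        assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
        assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg_sq should stay on cuda"
        adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], step)
    else:
        raise RuntimeError(f"Unsupported device type: {target_device.type}")
```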