Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 01:55:12 +00:00)
[zero] sharded optim support hybrid cpu adam (#486)
* sharded optim support hybrid cpu adam
* update unit test
* polish docstring
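The "hybrid" in the title refers to letting the ZeRO sharded optimizer drive CPUAdam when only part of the fp32 optimizer states are offloaded to the host: a parameter whose states sit on the CPU is updated there, while a parameter whose states stay on the GPU is updated on the GPU. Below is a minimal sketch of that idea, not ColossalAI's actual implementation; the name hybrid_adam_step, the states mapping, and the plain torch math standing in for the fused CPU/CUDA kernels are all assumptions made for illustration.

    import torch

    def hybrid_adam_step(params, states, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        # Sketch only: each parameter is assumed to live on the same device as
        # its fp32 states, and plain torch ops stand in for the fused kernels.
        beta1, beta2 = betas
        for p in params:
            st = states[p]                      # holds 'step', 'exp_avg', 'exp_avg_sq'
            st['step'] += 1
            exp_avg, exp_avg_sq = st['exp_avg'], st['exp_avg_sq']
            # Run the update wherever the states live (CPU for offloaded params,
            # CUDA otherwise); only the gradient may need to be moved.
            grad = p.grad if p.grad.device == exp_avg.device else p.grad.to(exp_avg.device)
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            bias1 = 1 - beta1 ** st['step']
            bias2 = 1 - beta2 ** st['step']
            denom = (exp_avg_sq / bias2).sqrt_().add_(eps)
            p.data.addcdiv_(exp_avg, denom, value=-lr / bias1)

    # Example wiring (hypothetical): this parameter is "offloaded", so its data,
    # grad and fp32 states all live on the CPU; a GPU-resident parameter would
    # keep all of these on 'cuda' instead.
    p = torch.nn.Parameter(torch.randn(4))
    p.grad = torch.randn(4)
    states = {p: {'step': 0,
                  'exp_avg': torch.zeros_like(p),
                  'exp_avg_sq': torch.zeros_like(p)}}
    hybrid_adam_step([p], states)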
@@ -1,9 +1,13 @@
-import torch
 import math
+
+import torch


 class CPUAdam(torch.optim.Optimizer):
     optimizer_id = 0
+    # Number of fp32 shards for per parameter
+    # Param weight, grad, momentum and variance
+    num_fp32_shards_per_param = 4

     def __init__(self,
                  model_params,
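The new class attribute records that CPUAdam keeps four fp32 shards per parameter (weight, gradient, momentum and variance), a count a caller such as the sharded optimizer can use when budgeting host memory for offloaded states. A rough illustration of that accounting, with a helper name that is hypothetical and not taken from the repo:

    def estimate_cpu_adam_bytes(num_params: int, num_fp32_shards_per_param: int = 4) -> int:
        # Each shard (weight, grad, momentum, variance) is stored in fp32 = 4 bytes.
        fp32_bytes = 4
        return num_params * num_fp32_shards_per_param * fp32_bytes

    # e.g. fully offloading a 1-billion-parameter model would need roughly
    # 1e9 * 4 * 4 bytes = 16 GB of host memory for these shards.
    print(estimate_cpu_adam_bytes(1_000_000_000))   # 16000000000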
@@ -106,10 +110,6 @@ class CPUAdam(torch.optim.Optimizer):
                                       group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                       state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
                 elif target_device.type == 'cuda':
-                    # FIXME() prepare grad on cuda
-                    if p.grad.device.type == 'cpu':
-                        p.grad = p.grad.to(target_device)
-
                     assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                     assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
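The shrinking hunk (10 lines down to 6) matches dropping the FIXME block that implicitly copied a CPU gradient to CUDA: after this change the CUDA branch only asserts that exp_avg and exp_avg_sq already live on the GPU, so gradients for GPU-resident states are expected to arrive on the GPU. A hedged caller-side sketch of that contract follows; the helper name and the states mapping are hypothetical, not part of ColossalAI.

    from typing import Dict
    import torch

    def place_grads_next_to_states(params, states: Dict[torch.nn.Parameter, dict]) -> None:
        # Caller-side responsibility after this change: ensure each gradient is on
        # the same device as that parameter's exp_avg / exp_avg_sq before step(),
        # since CPUAdam no longer performs this copy in its cuda branch.
        for p in params:
            if p.grad is None:
                continue
            state_device = states[p]['exp_avg'].device
            if p.grad.device != state_device:
                p.grad = p.grad.to(state_device)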