[test] make zero engine test really work (#447)

Author: Jiarui Fang
Date: 2022-03-17 17:24:25 +08:00
Committed by: GitHub
Parent: bb2790cf0b
Commit: 0fcfb1e00d
7 changed files with 39 additions and 28 deletions


@@ -20,6 +20,7 @@ class CPUAdam(torch.optim.Optimizer):
         The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
         The sharded param of model_params can reside on both CPU and CUDA.
         """
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
         self.opt_id = CPUAdam.optimizer_id
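For context, a minimal sketch of how this optimizer might be constructed. The keyword arguments mirror the default_args visible in the hunk above; the import path and the toy model are assumptions, not taken from this diff:

import torch
from colossalai.nn.optimizer import CPUAdam  # assumed import path

# toy stand-in; in practice model_params come from a ShardedModelV2 instance
model = torch.nn.Linear(16, 16)
optimizer = CPUAdam(model.parameters(),
                    lr=1e-3,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    weight_decay=0.0,
                    bias_correction=True)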
@@ -34,7 +35,8 @@ class CPUAdam(torch.optim.Optimizer):
         self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
 
     def __del__(self):
-        self.cpu_adam_op.destroy_adam(self.opt_id)
+        if self.cpu_adam_op:
+            self.cpu_adam_op.destroy_adam(self.opt_id)
 
     def torch_adam_update(self,
                           data,
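The guard added in this hunk keeps failed initialization and interpreter shutdown from crashing in the destructor: if cpu_adam_op is None (e.g., the C++ extension never loaded), destroy_adam is simply skipped. A standalone sketch of the pattern with hypothetical names; getattr additionally covers the case where __init__ raised before the attribute was assigned:

class _FakeNativeHandle:
    # stand-in for a compiled extension handle (hypothetical)
    def destroy(self):
        print("native state freed")

class NativeOptimizer:
    def __init__(self, load_ok=True):
        # stays None when the extension is unavailable
        self.native_op = _FakeNativeHandle() if load_ok else None

    def __del__(self):
        # getattr covers __init__ raising before assignment;
        # the truthiness test mirrors `if self.cpu_adam_op:` above
        op = getattr(self, "native_op", None)
        if op:
            op.destroy()

ok = NativeOptimizer()
del ok                           # frees the native state
broken = NativeOptimizer(load_ok=False)
del broken                       # no handle, no crash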
@@ -72,7 +74,6 @@ class CPUAdam(torch.optim.Optimizer):
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             with torch.enable_grad():
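The step signature above follows the standard PyTorch closure protocol: because step runs under @torch.no_grad(), a supplied closure must be re-evaluated with gradients re-enabled, and its loss is returned to the caller. A minimal self-contained sketch of that protocol (a plain SGD update stands in for the real CPUAdam math):

import torch

class SketchOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # re-enable autograd so the closure's forward/backward works
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])
        return loss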