[zero] sharded model supports the reuse of fp16 shard (#495)

* sharded model supports reuse of fp16 shard

* rename variable

* polish code

* polish code

* polish code
Author: ver217
Date: 2022-03-23 14:59:59 +08:00
Committed by: GitHub
Parent: f24b5ed201
Commit: 9ec1ce6ab1
7 changed files with 62 additions and 42 deletions
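
In short, the sharded model now writes the half-precision copy of each fp32 master shard back into an existing fp16 buffer instead of allocating a new tensor every iteration. A minimal sketch of that reuse pattern, using a hypothetical ShardedParam holder (the names are illustrative only, not the ColossalAI API):

import torch

class ShardedParam:
    # Illustrative holder for one parameter shard (not the actual ColossalAI class).
    def __init__(self, fp32_shard: torch.Tensor):
        self.fp32_shard = fp32_shard   # master weights kept in fp32
        self.fp16_shard = None         # half-precision copy, allocated once and reused

    def get_fp16_shard(self) -> torch.Tensor:
        # Allocate the fp16 buffer only on first use; afterwards just refresh
        # its contents from the fp32 master shard instead of allocating a new
        # tensor every training step.
        if self.fp16_shard is None:
            self.fp16_shard = torch.empty_like(self.fp32_shard, dtype=torch.float16)
        self.fp16_shard.copy_(self.fp32_shard)   # in-place cast/copy, no new allocation
        return self.fp16_shard

# The same tensor object is handed back on every call.
p = ShardedParam(torch.randn(4))
assert p.get_fp16_shard() is p.get_fp16_shard()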


@@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
                       bias_correction2,
                       loss_scale,
                       use_adamw=False):
+        # FIXME(ver217): remove the line below when replacing torch adam with fused adam
+        grad = grad.float()
         if loss_scale is not None:
             grad.div_(loss_scale)
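
For context, the two added lines let this optimizer path accept half-precision gradients (which the reused fp16 shard presumably feeds it) by upcasting before the loss scale is removed. A rough stand-alone illustration of that unscaling step, with a made-up loss-scale value:

import torch

loss_scale = 1024.0                                  # example static loss scale (made-up value)
grad_fp16 = (torch.randn(8) * loss_scale).half()     # pretend this is an fp16 gradient shard

# Mirror the added lines: upcast to fp32 first, then unscale in place.
grad = grad_fp16.float()
if loss_scale is not None:
    grad.div_(loss_scale)

print(grad.dtype)   # torch.float32, ready for the torch adam update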