fixed zero level 3 dtype bug (#76)

This commit is contained in:
Frank Lee
2021-12-20 17:00:53 +08:00
committed by GitHub
parent 632e622de8
commit 91c327cb44
5 changed files with 16 additions and 12 deletions

View File

@@ -1,7 +1,6 @@
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
import apex.amp as apex_amp
def convert_to_apex_amp(model: nn.Module,
@@ -19,6 +18,7 @@ def convert_to_apex_amp(model: nn.Module,
:return: (model, optimizer)
:rtype: Tuple
"""
import apex.amp as apex_amp
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
optimizer = ApexAMPOptimizer(optimizer)
return model, optimizer