[amp] included dict for type casting of model output (#1102)

This commit is contained in:
Frank Lee
2022-06-13 14:18:04 +08:00
committed by GitHub
parent 5a9d8ef4d5
commit 72bd7c696b

View File

@@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
return out