mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-14 22:23:23 +00:00
[amp] included dict for type casting of model output (#1102)
This commit is contained in:
parent
5a9d8ef4d5
commit
72bd7c696b
@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module):
|
||||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
return out
|
||||
|
Loading…
Reference in New Issue
Block a user