diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py index d8bbaad8f..02eae80b9 100644 --- a/colossalai/amp/naive_amp/naive_amp.py +++ b/colossalai/amp/naive_amp/naive_amp.py @@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module): out = self._convert_to_fp32(out) elif isinstance(out, (tuple, list)): out = [self._convert_to_fp32(val) for val in out] + elif isinstance(out, dict): + out = {key: self._convert_to_fp32(val) for key, val in out.items()} return out