diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 1eb03d1a7..56d224307 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -29,6 +29,7 @@ def convert_to_zero(model: nn.Module, :return: (model, optimizer) :rtype: Tuple """ + import deepspeed assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided' model = NaiveAMPModel(model, output_to_fp32=False)