From 6690a61b4dab4176a445241d97cc37557b74d528 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 21 Jun 2022 11:33:53 +0800 Subject: [PATCH] [hotfix] prevent nested ZeRO (#1140) --- colossalai/zero/sharded_model/sharded_model_v2.py | 1 + colossalai/zero/sharded_optim/sharded_optim_v2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index d61ea5373..5e06eb646 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module): tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, reuse_fp16_shard: bool = False): + assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' super().__init__() self.logger = get_dist_logger() diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index 610665422..95ab70708 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -87,6 +87,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): mp_process_group: Optional[ProcessGroup] = None, verbose: bool = False) -> None: assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' + assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.' super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy