diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 5f087ecab..d61ea5373 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -2,7 +2,6 @@ import functools from collections import OrderedDict from typing import Any, Optional, Iterator, Tuple from copy import deepcopy -from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX import itertools import torch import torch.distributed as dist @@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor) +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + class ShardedModelV2(nn.Module): """