diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 3dada00cd..f28b3e82c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -11,6 +11,7 @@ import torch.nn as nn from torch.optim import Optimizer from colossalai.tensor.d_tensor.d_tensor import DTensor +from .index_file import CheckpointIndexFile SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin"