diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index 8ce6d7335..cf404038c 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -1,5 +1,6 @@ # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 import json +import warnings from dataclasses import asdict, dataclass from typing import Dict, List, Optional, Tuple @@ -8,8 +9,10 @@ from safetensors.torch import _TYPES, load_file, safe_open try: from tensornvme.async_file_io import AsyncFileWriter -except ModuleNotFoundError: - raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") +except Exception: + warnings.warn( + "Please install the latest tensornvme to use async save. pip install git+https://github.com/hpcaitech/TensorNVMe.git" + ) _TYPES_INV = {v: k for k, v in _TYPES.items()} import io