mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] fix error for torch 2.0 (#2243)
@@ -69,6 +69,7 @@ class ColoTensor(torch.Tensor):
         data (torch.Tensor): a torch tensor used as the payload of the colotensor.
         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
+    torch_major = int(torch.__version__.split('.')[0])
     torch_minor = int(torch.__version__.split('.')[1])

     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@@ -168,7 +169,7 @@ class ColoTensor(torch.Tensor):
         if func in _COLOSSAL_OPS:
             func = _COLOSSAL_OPS[func]

-        if cls.torch_minor >= 12:
+        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
             # in order to trigger the pre-op hook in the forward of a checkpoint module,
             # we have to capture the `backward` function
             # and make sure that it does not run in the `torch._C.DisableTorchFunction()` context
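
Context on the fix (this note is not part of the commit): the old guard compared only the minor version number, so on torch "2.0.0" torch_minor is 0, `cls.torch_minor >= 12` evaluated to False, and the backward-capture path needed for checkpointing was skipped even though 2.0 is newer than 1.12. Below is a minimal sketch of the old versus fixed guard; the helper name is hypothetical and does not exist in ColossalAI:

# Hypothetical sketch of the version guard this commit fixes;
# the helper name is illustrative only.
def needs_backward_capture(version: str) -> bool:
    # Same parsing as ColoTensor.torch_major / torch_minor.
    major = int(version.split('.')[0])
    minor = int(version.split('.')[1])
    # Fixed guard: true for any torch >= 1.12, including 2.x.
    return major > 1 or (major == 1 and minor >= 12)

# Old guard (`minor >= 12`) versus the fixed one:
#   "1.12.1" -> minor = 12 -> old: True   fixed: True
#   "2.0.0"  -> minor = 0  -> old: False  fixed: True   (the torch 2.0 bug)
#   "1.11.0" -> minor = 11 -> old: False  fixed: False
assert needs_backward_capture("1.12.1")
assert needs_backward_capture("2.0.0")
assert not needs_backward_capture("1.11.0")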