[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)
Mirror of https://github.com/hpcaitech/ColossalAI.git

Commit message:
* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12
* [zero] add cpu shard init
* [zero] add tiny example test
* [colo_tensor] fix bugs for torch-1.11
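The incompatibility concerns models trained with ZeRO whose forward pass uses activation checkpointing: when `loss.backward()` runs under torch 1.12, the recomputed forward inside the checkpoint must still trigger ZeRO's pre-op hooks. A minimal sketch of that usage pattern in plain PyTorch (the module, shapes, and names below are illustrative and not taken from this PR):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative model; under ZeRO its parameters would be sharded and gathered
# by pre-op hooks that must also fire during the recomputed forward pass.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(4, 16, requires_grad=True)

out = checkpoint(model, x)    # forward is recomputed later, during backward
out.sum().backward()          # the call path this hotfix adjusts for torch 1.12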
colossalai/tensor/colo_tensor.py

@@ -1,14 +1,15 @@
-from .op_wrapper import _COLOSSAL_OPS
-from .const import TensorType
 from copy import copy
-import torch
 from functools import lru_cache
+from typing import Callable, Optional, Set
+
+import torch
 
-from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import ProcessGroup, ReplicaSpec
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional, Set, Callable
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+
+from .const import TensorType
+from .op_wrapper import _COLOSSAL_OPS
 
 
 @lru_cache(None)
@@ -57,25 +58,26 @@ class ColoTensor(torch.Tensor):
         >>> pg = ProcessGroup()
         >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
         >>> # The tensor passed in is a tensor after sharding but not a global tensor.
         >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
         >>>                        dims=[0],
         >>>                        num_partitions=[world_size])
         >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
         >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
 
     Args:
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
+    torch_minor = int(torch.__version__.split('.')[1])
 
     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
         """
         The signature of the __new__ has to be consistent with the torch.Tensor.
 
         Args:
             data (torch.Tensor): a torch tensor used as the payload the colotensor.
             spec (TensorSpec, optional): the tensor spec of initialization.
 
         Returns:
             ColoTensor: a ColoTensor wrappers the data.
         """
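The `torch_minor` attribute added at the end of the docstring gates the torch-1.12 handling introduced further down. As a quick illustration (not part of the commit) of what the expression evaluates to, assuming a 1.x version string:

import torch

# e.g. "1.12.1+cu113".split('.')[1] == "12", so torch_minor is 12 on any torch 1.12.x build
torch_minor = int(torch.__version__.split('.')[1])
if torch_minor >= 12:
    print("torch >= 1.12: Tensor.backward is routed outside DisableTorchFunction")

Checking only the minor component is enough here to separate torch 1.11 from 1.12, the two versions the commit message mentions.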
@@ -112,7 +114,7 @@ class ColoTensor(torch.Tensor):
         return self.process_group
 
     def set_process_group(self, pg: ProcessGroup):
-        """set_process_group
+        """set_process_group
         change the pg of the ColoTensor. Note that the valid use cases is limited.
         Only existing pg is DP and dist spec is REPLICaTE is valid.
 
@@ -135,7 +137,7 @@ class ColoTensor(torch.Tensor):
         return self.process_group.tp_world_size()
 
     def set_dist_spec(self, dist_spec: _DistSpec):
-        """set_dist_spec
+        """set_dist_spec
         set dist spec and change the payloads.
 
         Args:
@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
         if func in _COLOSSAL_OPS:
             func = _COLOSSAL_OPS[func]
 
+        if cls.torch_minor >= 12:
+            # in order to trigger pre-op hook in the forward of checkpoint module
+            # we have to capture the `backward` function
+            # and make sure that it does not in `torch._C.DisableTorchFunction()` context
+            if func is torch.Tensor.backward:
+                assert len(args) == 1    # only has 1 paramter
+                backward_tensor = torch.Tensor(args[0])
+                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
+                return backward_tensor.backward(**tensor_kwargs)
+
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
             if func in _get_my_nowrap_functions():
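To make the intent of the added branch easier to see in isolation, here is a self-contained sketch of the same pattern (the class below is an illustrative stand-in, not ColossalAI's implementation; the real ColoTensor additionally dispatches `_COLOSSAL_OPS` and re-wraps its outputs): a tensor subclass whose `__torch_function__` normally runs ops under `DisableTorchFunction`, but hands `torch.Tensor.backward` to a plain-tensor re-wrap so the autograd call does not execute inside that context.

import torch


class _WrappedTensor(torch.Tensor):
    """Illustrative stand-in for ColoTensor, showing only the backward handling."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.Tensor.backward:
            # Re-wrap the wrapped tensor (and any tensor kwargs such as `gradient`)
            # as plain torch.Tensor and call backward on that, so the autograd call
            # does not run inside torch._C.DisableTorchFunction() and the hooks set
            # up by activation checkpointing still fire.
            plain_self = torch.Tensor(args[0])
            plain_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
            return plain_self.backward(**plain_kwargs)
        # Everything else: run the op with __torch_function__ dispatch disabled.
        with torch._C.DisableTorchFunction():
            return func(*args, **kwargs)

The re-wrap mirrors the committed `backward_tensor = torch.Tensor(args[0])` line: it assumes the plain-tensor wrapper shares the original data and autograd connectivity, so gradients still reach the original leaves while checkpoint's hooks observe the backward call.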
@@ -178,7 +190,7 @@ class ColoTensor(torch.Tensor):
         return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
 
     def _redistribute(self, dist_spec: _DistSpec) -> None:
-        """_redistribute
+        """_redistribute
         Note the function will not handle the logic of backward propagation!
         It is used during model tensor initializations as an internal function.
 
@@ -191,12 +203,12 @@ class ColoTensor(torch.Tensor):
         self.dist_spec = dist_spec
 
     def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
-        """redistribute
+        """redistribute
         Redistribute the tensor among processes. The rule is like this:
 
         1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
         DP process group not changed.
 
         2. If the pg is not not None and not equal to the current process group.
         First, convert the tensor as replicated among the TP process group.
         Second, reset the process group to the new pg.
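For orientation, a hypothetical call sequence following the two rules quoted in the `redistribute` docstring might look like the sketch below. The import paths, the `tp=2` degree, and the assumption of an already-initialized ColossalAI / torch.distributed environment are illustrative and not taken from this commit.

import torch

from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec

pg = ProcessGroup()    # current process group
t = ColoTensor.from_torch_tensor(torch.randn(4, 8), ColoTensorSpec(pg, ReplicaSpec()))

# Rule 1: pg is None -> only the dist spec is converted, within the current TP group.
t1 = t.redistribute(ReplicaSpec())

# Rule 2: a different pg is given -> the tensor is first replicated across the TP group,
# then its process group is reset to the new one.
new_pg = ProcessGroup(tp=2)
t2 = t.redistribute(ReplicaSpec(), pg=new_pg)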
@@ -220,7 +232,7 @@ class ColoTensor(torch.Tensor):
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
 
     def to_replicate_(self):
-        """to_replicate_
+        """to_replicate_
 
         an inline member function, converting dist spec of the tensor to REPLICATE
         """