[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
HELSON
2022-11-02 16:11:34 +08:00
committed by GitHub
parent 32c1b843a9
commit c6a1a62636
9 changed files with 1041 additions and 951 deletions
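For context, the incompatibility fixed here appears when activation checkpointing recomputes a module's forward during `loss.backward()`: on torch >= 1.12, `backward` reaches ColoTensor's `__torch_function__` handler and would otherwise run inside `torch._C.DisableTorchFunction()`, so ZeRO's pre-op hooks never fire during the recomputation. Below is a minimal, hedged sketch of the affected usage pattern using only plain PyTorch; the module and training step are illustrative, not the repository's actual test.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy block whose forward is recomputed inside backward() via checkpointing."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        # With activation checkpointing, self.linear's forward runs a second time
        # during backward(). In the ZeRO setting the parameters are ColoTensors,
        # and their __torch_function__ pre-op hook must fire during that
        # recomputation so the sharded weights can be gathered.
        return checkpoint(self.linear, x)


model = Block()
loss = model(torch.randn(4, 8)).sum()
# In the real setting (ColoTensor parameters), this backward call is the one
# the patched __torch_function__ intercepts on torch >= 1.12.
loss.backward()
```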


@@ -1,14 +1,15 @@
-from .op_wrapper import _COLOSSAL_OPS
-from .const import TensorType
 from copy import copy
-import torch
 from functools import lru_cache
+from typing import Callable, Optional, Set
-from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import ProcessGroup, ReplicaSpec
+import torch
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional, Set, Callable
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from .const import TensorType
+from .op_wrapper import _COLOSSAL_OPS
 @lru_cache(None)
@@ -57,25 +58,26 @@ class ColoTensor(torch.Tensor):
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec=ColoTensorSpec(pg, ReplicaSpec()))
>>> # The tensor passed in is already sharded, not a global (full) tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>>                        dims=[0],
>>>                        num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""
The signature of __new__ has to be consistent with that of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrapping the data.
"""
@@ -112,7 +114,7 @@ class ColoTensor(torch.Tensor):
return self.process_group
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases are limited:
it is only valid when the existing pg is DP-only and the dist spec is REPLICATE.
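A hedged usage sketch of the valid case described above; the `tp_degree` constructor argument name and the launch setup are assumptions, not taken from this diff.

```python
import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec

# Assumes the distributed environment has already been initialized
# (e.g. via colossalai.launch on multiple ranks).
dp_only_pg = ProcessGroup()                        # default group: data parallel only
spec = ColoTensorSpec(dp_only_pg, ReplicaSpec())   # replicated dist spec
t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec)

# Valid case: the tensor is replicated and currently bound to a DP-only group,
# so it may be rebound to a group that also carries a tensor-parallel dimension.
t.set_process_group(ProcessGroup(tp_degree=2))     # argument name assumed
```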
@@ -135,7 +137,7 @@ class ColoTensor(torch.Tensor):
return self.process_group.tp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
"""set_dist_spec
set the dist spec and change the payload accordingly.
Args:
@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func]
if cls.torch_minor >= 12:
# in order to trigger the pre-op hook in the forward of a checkpointed module,
# we have to capture the `backward` function
# and make sure it does not run inside the `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1  # only has 1 parameter
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions():
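Independent of ColossalAI, here is a small hedged sketch of the dispatch behavior the hunk above relies on: anything executed under `torch._C.DisableTorchFunction()` bypasses `__torch_function__` on tensor subclasses, which is why `backward` must be re-issued outside that context for the checkpoint recomputation to reach the hooks. The class and variable names below are illustrative.

```python
import torch


class TracingTensor(torch.Tensor):
    """Toy subclass that records every torch function dispatched through it."""
    calls = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.calls.append(getattr(func, "__name__", repr(func)))
        with torch._C.DisableTorchFunction():
            return func(*args, **(kwargs or {}))


x = torch.randn(3).as_subclass(TracingTensor)

y = x + 1                                  # dispatched: __torch_function__ runs
with torch._C.DisableTorchFunction():
    z = x + 1                              # bypassed: __torch_function__ never runs

print(TracingTensor.calls)                 # only the first add was recorded
```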
@@ -178,7 +190,7 @@ class ColoTensor(torch.Tensor):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute
"""_redistribute
Note that this function does not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
@@ -191,12 +203,12 @@ class ColoTensor(torch.Tensor):
self.dist_spec = dist_spec
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute
"""redistribute
Redistribute the tensor among processes. The rules are as follows:
1. If pg is None, redistribute the tensor payload within the current TP process group;
the DP process group is left unchanged.
2. If pg is not None and differs from the current process group:
first, convert the tensor to be replicated within the TP process group;
second, reset the process group to the new pg.
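A hedged sketch of the two rules above. It requires a distributed launch; the process-group degrees and the `tp_degree`/`dp_degree` argument names are assumptions, while `ShardSpec(dims, num_partitions)` and `redistribute(dist_spec, pg=...)` follow the names visible in this diff.

```python
import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ProcessGroup,
                               ReplicaSpec, ShardSpec)

# Assumes 4 ranks: a 2-way TP group crossed with a 2-way DP group.
pg = ProcessGroup(tp_degree=2, dp_degree=2)        # argument names assumed
t = ColoTensor.from_torch_tensor(torch.randn(4, 4), ColoTensorSpec(pg, ReplicaSpec()))

# Rule 1: pg is None -- reshard the payload inside the current TP group;
# the DP group is left untouched.
t1 = t.redistribute(ShardSpec(dims=[0], num_partitions=[2]))

# Rule 2: a different pg -- the tensor is first made replicated across the old
# TP group, then rebound to the new group before the new dist spec is applied.
t2 = t.redistribute(ReplicaSpec(), pg=ProcessGroup(tp_degree=4))
```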
@@ -220,7 +232,7 @@ class ColoTensor(torch.Tensor):
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
def to_replicate_(self):
"""to_replicate_
"""to_replicate_
an in-place member function that converts the dist spec of the tensor to REPLICATE
"""