[Tensor] add from_pretrained support and bert pretrained test (#921)

* add from_pretrained support and test

* polish

* polish

* polish

* polish
Authored by Ziyue Jiang, committed via GitHub on 2022-05-09 16:11:47 +08:00
parent 1d625fcd36
commit c195d2814c
4 changed files with 158 additions and 20 deletions
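
At the tensor level, "from_pretrained support" amounts to overwriting an existing ColoTensor's payload with a checkpoint weight through plain .data assignment, which is what the diff below adds. A rough sketch of that flow, assuming HuggingFace transformers' BertModel and the colossalai.tensor import path (neither appears in the hunks shown here, and this is not the test added by the PR):

import torch
from transformers import BertModel          # assumed: HF transformers, as in the new BERT pretrained test
from colossalai.tensor import ColoTensor    # assumed import path for ColoTensor

# A ColoTensor standing in for one model parameter (BERT-base word embedding shape).
param = ColoTensor.init_from_torch_tensor(torch.empty(30522, 768))

# Pull the matching pretrained weight out of a HuggingFace checkpoint...
state_dict = BertModel.from_pretrained('bert-base-uncased').state_dict()

# ...and copy it in through the new data setter (accepts torch.Tensor or ColoTensor).
param.data = state_dict['embeddings.word_embeddings.weight']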


@@ -1,7 +1,7 @@
 from .op_wrapper import _COLOSSAL_OPS
 import torch
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Union
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
@@ -55,6 +55,15 @@ class ColoTensor(object):
     def data(self):
         return self._torch_tensor.data
 
+    @data.setter
+    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
+        if isinstance(tensor, ColoTensor):
+            self._torch_tensor.data = tensor.data
+        elif isinstance(tensor, torch.Tensor):
+            self._torch_tensor.data = tensor
+        else:
+            raise NotImplementedError
+
     @property
     def grad(self):
         return self._torch_tensor.grad
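
Taken on its own, the new setter dispatches on the type of the assigned value; a minimal sketch of the three branches (the colossalai.tensor import path is an assumption):

import torch
from colossalai.tensor import ColoTensor   # assumed import path

a = ColoTensor.init_from_torch_tensor(torch.zeros(4))
b = ColoTensor.init_from_torch_tensor(torch.ones(4))

a.data = torch.full((4,), 2.0)      # torch.Tensor branch: rebinds the underlying payload
a.data = b                          # ColoTensor branch: rebinds to b's underlying data
try:
    a.data = [1.0, 2.0, 3.0, 4.0]   # any other type falls through to NotImplementedError
except NotImplementedError:
    pass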
@@ -148,14 +157,31 @@ class ColoTensor(object):
         assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        if self._shard_pattern == ShardPattern.Row:
-            dim = 0
-        elif self._shard_pattern == ShardPattern.Col:
-            dim = -1
+        dim = self._get_gather_dim()
         self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
         self._shard_pattern = ShardPattern.NA
         self._size = self._torch_tensor.size()
 
+    def global_torch_tensor(self) -> torch.Tensor:
+        out_tensor = self.torch_tensor()
+        if self.is_gathered():
+            return out_tensor
+
+        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
+        world_size = gpc.get_world_size(parallel_action.parallel_mode)
+        if world_size == 1:
+            return out_tensor
+
+        rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
+        tensor_list[rank] = out_tensor
+        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))
+
+        dim = self._get_gather_dim()
+        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
+        return out_tensor
+
     def is_gathered(self) -> bool:
         return self._shard_pattern == ShardPattern.NA
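
The gather in global_torch_tensor() is the standard all_gather-then-concatenate pattern. The same idea in isolation, without ColossalAI's process-group wrappers, assuming torch.distributed has already been initialized:

import torch
import torch.distributed as dist

def all_gather_shard(shard: torch.Tensor, dim: int) -> torch.Tensor:
    # Rebuild the full tensor from the equal-sized shard held by each rank;
    # assumes dist.init_process_group(...) has been called beforehand.
    world_size = dist.get_world_size()
    if world_size == 1:
        return shard
    buffers = [torch.empty_like(shard) for _ in range(world_size)]
    buffers[dist.get_rank()] = shard
    dist.all_gather(buffers, shard)
    return torch.cat(buffers, dim=dim).contiguous()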
@@ -212,9 +238,7 @@ class ColoTensor(object):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
 
     def __getattr__(self, name):
-
        def replace_tensor_with_colo(func):
-
            def execute_func(*args, **kwargs):
                # transform the ColoTensor args to torch Tensor.
                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
@@ -225,7 +249,9 @@ class ColoTensor(object):
             return execute_func
 
-        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
+        if hasattr(self._torch_tensor, name) == False:
+            raise AttributeError
+
         attr = getattr(self._torch_tensor, name)
         if isinstance(attr, Callable):
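
The switch from an assert to raising AttributeError is what keeps attribute delegation well behaved: hasattr() only treats AttributeError as "attribute missing", so the old assert would have turned a probe for a nonexistent attribute into an AssertionError. A toy illustration of the lookup rule (not ColossalAI code):

import torch

class Delegating:
    # Stand-in for ColoTensor's __getattr__ delegation, reduced to the lookup rule.
    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        if not hasattr(self._inner, name):
            raise AttributeError(name)   # lets hasattr()/getattr() report "missing" normally
        return getattr(self._inner, name)

t = Delegating(torch.zeros(2))
print(hasattr(t, 'numel'))      # True: delegated to the wrapped tensor
print(hasattr(t, 'no_such'))    # False, instead of blowing up with an AssertionError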
@@ -244,3 +270,12 @@ class ColoTensor(object):
             ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
             for output in outputs
         ])
+
+    def _get_gather_dim(self):
+        if self._shard_pattern == ShardPattern.Row:
+            dim = 0
+        elif self._shard_pattern == ShardPattern.Col:
+            dim = -1
+        else:
+            raise NotImplementedError
+        return dim
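
For reference, the dim choice in _get_gather_dim() matches how the two shard patterns split a tensor: row shards are concatenated back along dim 0, column shards along the last dim. A plain-torch illustration:

import torch

full = torch.arange(12.).reshape(4, 3)

row_shards = torch.chunk(full, 2, dim=0)    # ShardPattern.Row -> gather along dim 0
col_shards = torch.chunk(full, 3, dim=-1)   # ShardPattern.Col -> gather along dim -1

assert torch.equal(torch.cat(row_shards, dim=0), full)
assert torch.equal(torch.cat(col_shards, dim=-1), full)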