Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[Tensor] add from_pretrained support and bert pretrained test (#921)
* add from_pretrained support and test
* polish
* polish
* polish
* polish
@@ -1,7 +1,7 @@
 from .op_wrapper import _COLOSSAL_OPS
 
 import torch
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Union
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
@@ -55,6 +55,15 @@ class ColoTensor(object):
     def data(self):
         return self._torch_tensor.data
 
+    @data.setter
+    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
+        if isinstance(tensor, ColoTensor):
+            self._torch_tensor.data = tensor.data
+        elif isinstance(tensor, torch.Tensor):
+            self._torch_tensor.data = tensor
+        else:
+            raise NotImplementedError
+
     @property
     def grad(self):
         return self._torch_tensor.grad
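The setter added here mirrors assignment to a plain tensor's .data, which is what lets pretrained weights be copied into an already-constructed ColoTensor in place. A minimal sketch of the intended usage, assuming ColoTensor is importable from colossalai.tensor and that init_from_torch_tensor accepts a plain tensor as it does elsewhere in this diff (parameter names and shapes are illustrative):

    import torch
    from colossalai.tensor import ColoTensor  # import path assumed from this repo's layout

    # A parameter already wrapped as a ColoTensor (shape is illustrative).
    colo_param = ColoTensor.init_from_torch_tensor(torch.zeros(1024, 1024))

    # Weights loaded from a pretrained checkpoint (name is illustrative).
    pretrained_weight = torch.randn(1024, 1024)

    # The new setter accepts either branch of the Union annotation:
    colo_param.data = pretrained_weight                                      # torch.Tensor branch
    colo_param.data = ColoTensor.init_from_torch_tensor(pretrained_weight)   # ColoTensor branch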
@@ -148,14 +157,31 @@ class ColoTensor(object):
         assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        if self._shard_pattern == ShardPattern.Row:
-            dim = 0
-        elif self._shard_pattern == ShardPattern.Col:
-            dim = -1
+        dim = self._get_gather_dim()
         self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
         self._shard_pattern = ShardPattern.NA
         self._size = self._torch_tensor.size()
 
+    def global_torch_tensor(self) -> torch.Tensor:
+        out_tensor = self.torch_tensor()
+        if self.is_gathered():
+            return out_tensor
+
+        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
+        world_size = gpc.get_world_size(parallel_action.parallel_mode)
+        if world_size == 1:
+            return out_tensor
+
+        rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
+        tensor_list[rank] = out_tensor
+        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))
+
+        dim = self._get_gather_dim()
+        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
+
+        return out_tensor
+
     def is_gathered(self) -> bool:
         return self._shard_pattern == ShardPattern.NA
 
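The new global_torch_tensor() uses the usual all_gather-then-concatenate pattern: pre-allocate one buffer per rank, gather every rank's shard, and concatenate along the gather dimension. The same pattern reduced to a stand-alone torch.distributed sketch (gather_shard is an illustrative helper, not ColossalAI API, and assumes the default process group is already initialized):

    import torch
    import torch.distributed as dist

    def gather_shard(local_shard: torch.Tensor, dim: int = 0) -> torch.Tensor:
        """All-gather a sharded tensor and concatenate the pieces along `dim`."""
        world_size = dist.get_world_size()
        if world_size == 1:
            return local_shard

        # Pre-allocate one buffer per rank and put our own shard at our rank index,
        # mirroring what ColoTensor.global_torch_tensor() does above.
        tensor_list = [torch.empty_like(local_shard) for _ in range(world_size)]
        tensor_list[dist.get_rank()] = local_shard
        dist.all_gather(tensor_list, local_shard)

        return torch.cat(tensor_list, dim=dim).contiguous()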
@@ -212,9 +238,7 @@ class ColoTensor(object):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
 
     def __getattr__(self, name):
-
        def replace_tensor_with_colo(func):
-
            def execute_func(*args, **kwargs):
                # transform the ColoTensor args to torch Tensor.
                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
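For context, the execute_func shown above unwraps ColoTensor arguments back to plain tensors before calling the underlying torch function. A self-contained sketch of that unwrapping pattern, using an illustrative DummyWrapper class rather than ColoTensor itself:

    import torch

    class DummyWrapper:
        """Illustrative stand-in for ColoTensor: holds a torch.Tensor internally."""

        def __init__(self, tensor: torch.Tensor):
            self._t = tensor

        def torch_tensor(self) -> torch.Tensor:
            return self._t

    def replace_wrapper_with_torch(func):
        def execute_func(*args, **kwargs):
            # Unwrap wrapper arguments to plain tensors, as the execute_func above
            # does for ColoTensor, so the torch function sees only torch.Tensor inputs.
            args = [a.torch_tensor() if isinstance(a, DummyWrapper) else a for a in args]
            return func(*args, **kwargs)
        return execute_func

    x = DummyWrapper(torch.ones(2, 2))
    add = replace_wrapper_with_torch(torch.add)
    print(add(x, x))  # both wrapper args are unwrapped; prints a tensor of 2.0s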
@@ -225,7 +249,9 @@ class ColoTensor(object):
 
             return execute_func
 
-        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
+        if hasattr(self._torch_tensor, name) == False:
+            raise AttributeError
+
         attr = getattr(self._torch_tensor, name)
 
         if isinstance(attr, Callable):
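Swapping the assert for a raised AttributeError changes how attribute probing behaves: hasattr() treats AttributeError as "attribute absent" and returns False, whereas an AssertionError would propagate and abort, which is the behavior attribute-probing code generally expects. A minimal sketch of that difference with an illustrative wrapper class (not ColossalAI code):

    import torch

    class Wrapped:
        """Illustrative wrapper that forwards unknown attributes to an inner tensor."""

        def __init__(self, tensor: torch.Tensor):
            self._torch_tensor = tensor

        def __getattr__(self, name):
            if not hasattr(self._torch_tensor, name):
                # Raising AttributeError (rather than asserting) lets hasattr()
                # return False cleanly instead of crashing the caller.
                raise AttributeError(name)
            return getattr(self._torch_tensor, name)

    w = Wrapped(torch.zeros(3))
    print(hasattr(w, "shape"))          # True  - forwarded to the inner tensor
    print(hasattr(w, "no_such_attr"))   # False - AttributeError is swallowed by hasattr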
@@ -244,3 +270,12 @@ class ColoTensor(object):
                 ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
                 for output in outputs
             ])
+
+    def _get_gather_dim(self):
+        if self._shard_pattern == ShardPattern.Row:
+            dim = 0
+        elif self._shard_pattern == ShardPattern.Col:
+            dim = -1
+        else:
+            raise NotImplementedError
+        return dim
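_get_gather_dim() centralizes the mapping from shard pattern to concatenation axis: row shards are rejoined along dim 0, column shards along the last dim. A single-process sketch of what splitting and re-gathering along those two axes means (shard and gather are illustrative helpers, not ColossalAI API):

    import torch

    def shard(full: torch.Tensor, world_size: int, dim: int):
        """Split a tensor into equal chunks along `dim`, one per rank."""
        return list(torch.chunk(full, world_size, dim=dim))

    def gather(shards, dim: int) -> torch.Tensor:
        """Inverse of shard(): concatenate the per-rank chunks back together."""
        return torch.cat(shards, dim=dim).contiguous()

    weight = torch.arange(24.0).reshape(4, 6)

    # Row-sharded (ShardPattern.Row): each rank holds a slice of the rows, gather dim = 0.
    row_shards = shard(weight, world_size=2, dim=0)     # two (2, 6) pieces
    assert torch.equal(gather(row_shards, dim=0), weight)

    # Column-sharded (ShardPattern.Col): each rank holds a slice of the columns, gather dim = -1.
    col_shards = shard(weight, world_size=2, dim=-1)    # two (4, 3) pieces
    assert torch.equal(gather(col_shards, dim=-1), weight)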