[WIP] Applying ColoTensor on TP-1D-row Linear. (#831)

* revert zero tensors back

* [tensor] init row 1d linear
This commit is contained in:
Jiarui Fang
2022-04-22 14:03:26 +08:00
committed by GitHub
parent 595bedf767
commit ac88de6dfc
3 changed files with 101 additions and 8 deletions

View File

@@ -19,12 +19,18 @@ def colo_linear(types, args, kwargs, pg):
bias = None
else:
bias = kwargs.get('bias', None)
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
if isinstance(weight, ColoTensor):
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
if weight.shard_spec == None:
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
elif weight.shard_spec == '1Drow':
# TODO(jzy): implement 1Drow TP linear here.
raise NotImplementedError
else:
raise NotImplementedError
else:
return torch.nn.functional.linear(input_tensor, weight, bias)

View File

@@ -1,6 +1,6 @@
import torch
from .op_wrapper import _COLOSSAL_OPS
from typing import Tuple
from typing import Tuple, Optional
class ColoTensor(object):
@@ -21,20 +21,35 @@ class ColoTensor(object):
requires_grad=False,
pin_memory=False,
torch_tensor=torch.empty(0),
shard_spec: str = None,
):
self._size = size
self._dtype = dtype
self._requires_grad = requires_grad
self._pin_memory = pin_memory
self._torch_tensor = torch_tensor
self._shard_spec = shard_spec
@property
def shard_spec(self) -> Optional[str]:
return self._shard_spec
@property
def data(self):
return self._torch_tensor.data
@property
def grad(self):
return self._torch_tensor.grad
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor):
def init_from_torch_tensor(tensor: torch.Tensor, shard_spec: str = None) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.pin_memory,
torch_tensor=tensor)
torch_tensor=tensor,
shard_spec=shard_spec)
return colo_t
def del_torch_tensor(self) -> None:
@@ -67,7 +82,5 @@ class ColoTensor(object):
if kwargs is None:
kwargs = {}
kwargs = {
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return func(*args, **kwargs)