[tensor] lazy init (#823)

Jiarui Fang
2022-04-21 15:40:23 +08:00
committed by GitHub
parent 68dcd51d41
commit 2ecc3d7a55
2 changed files with 46 additions and 8 deletions


@@ -1,16 +1,48 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
+from typing import Tuple


 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
     2. It supports lazy init of the tensor's payload.
     3. It can hijack the torch functions which use ColoTensors as args and dispatch them to our customized functions.
     4. It supports distributing the tensor's payload to the shards among processes. (TODO)
     """

     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(self, t: torch.Tensor) -> None:
-        self._torch_tensor = t
+    def __init__(
+            self,
+            *size: Tuple[int],
+            dtype=None,
+            requires_grad=False,
+            pin_memory=False,
+            torch_tensor=None,
+    ):
+        self._size = size
+        self._dtype = dtype
+        self._requires_grad = requires_grad
+        self._pin_memory = pin_memory
+        self._torch_tensor = torch_tensor
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor):
+        colo_t = ColoTensor(*tensor.size(),
+                            dtype=tensor.dtype,
+                            requires_grad=tensor.requires_grad,
+                            pin_memory=tensor.is_pinned(),
+                            torch_tensor=tensor)
+        return colo_t
+
+    def torch_tensor(self) -> torch.Tensor:
+        if self._torch_tensor is None:
+            self._torch_tensor = torch.empty(*self._size,
+                                             dtype=self._dtype,
+                                             requires_grad=self._requires_grad,
+                                             pin_memory=self._pin_memory)
+        return self._torch_tensor
+
     @classmethod
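
A minimal usage sketch of the lazy-init flow this diff introduces (the import path is an assumption; the diff does not show where ColoTensor is exported):

import torch
from colossalai.tensor.colo_tensor import ColoTensor  # assumed module path

# Describe a 4x4 fp32 tensor; no payload is allocated yet.
lazy_t = ColoTensor(4, 4, dtype=torch.float32)

# First access materializes the payload via torch.empty(...).
payload = lazy_t.torch_tensor()
assert payload.shape == (4, 4)

# Wrapping an existing tensor stores it directly, so nothing is re-allocated.
eager_t = ColoTensor.init_from_torch_tensor(torch.randn(2, 3))
assert eager_t.torch_tensor().shape == (2, 3)

Deferring allocation until torch_tensor() is called lets the runtime decide placement (or sharding, per the TODO in the docstring) before any memory is committed.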