[tensor] reorganize files (#820)

commit 0ce8924ceb (parent ab962b9735)
Author: Jiarui Fang
Date: 2022-04-21 14:15:48 +08:00 (committed by GitHub)
11 changed files with 71 additions and 76 deletions


@@ -1,3 +0,0 @@
-from .init import stateful_uniform
-from .linear import stateful_linear
-from .element_wise import stateful_mean


@@ -1,17 +0,0 @@
-from typing import (
-    Callable,
-    Dict,
-)
-
-# Custom sharded ops
-_STATEFUL_OPS: Dict[str, Callable] = {}
-
-
-def _register_stateful_op(op, func):
-    from inspect import signature
-    if len(signature(func).parameters) != 4:
-        raise TypeError(f'Custom stateful op function expects signature: '
-                        f'(types, args, kwargs, process_group), but received '
-                        f'signature: {signature(func)}')
-    global _STATEFUL_OPS
-    _STATEFUL_OPS[op] = func


@@ -0,0 +1,7 @@
+from .op_wrapper import (
+    colo_op_impl,)
+from .colo_tensor import ColoTensor
+from .utils import convert_parameter
+from ._ops import *
+
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
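For orientation, a minimal usage sketch of the reorganized public API (assuming the package is installed as `colossalai`); the wildcard import of `._ops` above is what registers the built-in `colo_*` handlers as a side effect:

import torch
from colossalai.tensor import ColoTensor

# Importing colossalai.tensor also executes `from ._ops import *`, which
# registers the built-in colo_* operator overrides as a side effect.
t = ColoTensor(torch.randn(4, 4))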


@@ -0,0 +1,3 @@
+from .init import colo_uniform
+from .linear import colo_linear
+from .element_wise import colo_mean


@@ -1,17 +1,17 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor


-@stateful_op_impl(torch.mean)
-def stateful_mean(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.mean)
+def colo_mean(types, args=(), kwargs=None, pg=None):
     stateful_tensor = args[0]
     return torch.mean(stateful_tensor.torch_tensor())


 def register_elementwise_op(op):

-    @stateful_op_impl(op)
+    @colo_op_impl(op)
     def elementwise_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for the elementwise op such
@@ -20,8 +20,8 @@ def register_elementwise_op(op):
         """
         input_tensor = args[0]
         # Validate types
-        if not isinstance(input_tensor, StatefulTensorV2):
-            raise TypeError("input needs to be a StatefulTensorV2")
+        if not isinstance(input_tensor, ColoTensor):
+            raise TypeError("input needs to be a ColoTensor")
         return op(input_tensor.torch_tensor())
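A small sketch of the resulting behaviour, mirroring the element-wise test further down in this diff (nothing here beyond what the diff itself exercises):

import torch
from colossalai.tensor import ColoTensor

t_ref = torch.randn(3, 5)
t = ColoTensor(t_ref.clone())
# torch.mean is intercepted by ColoTensor.__torch_function__ and forwarded to
# colo_mean(types, args, kwargs, pg), which unwraps the tensor and averages it.
assert torch.mean(t) == torch.mean(t_ref)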


@@ -1,5 +1,5 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
+from colossalai.tensor.op_wrapper import colo_op_impl


 def validate_param(param, param_name):
@@ -7,8 +7,8 @@ def validate_param(param, param_name):
         raise ValueError(f"param: {param_name} shouldn't be None!")


-@stateful_op_impl(torch.nn.init.uniform_)
-def stateful_uniform(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.nn.init.uniform_)
+def colo_uniform(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
     distribution :math:`\mathcal{U}(a, b)`.


@@ -1,11 +1,11 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from ..stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.colo_tensor import ColoTensor
 from packaging import version


-@stateful_op_impl(torch.nn.functional.linear)
-def stateful_linear(types, args, kwargs, pg):
+@colo_op_impl(torch.nn.functional.linear)
+def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
@@ -19,11 +19,11 @@ def stateful_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
-        if isinstance(bias, StatefulTensorV2):
+        if isinstance(bias, ColoTensor):
             bias = bias.torch_tensor()

     # Add communication logic before and after linear call.
-    if isinstance(weight, StatefulTensorV2):
+    if isinstance(weight, ColoTensor):
         return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
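A hedged usage sketch for the renamed handler; shapes are illustrative, and the bias is passed as a keyword because the hunk above reads it from `kwargs`:

import torch
from colossalai.tensor import ColoTensor

weight = ColoTensor(torch.randn(8, 4))    # same layout as an nn.Linear(4, 8) weight
bias = ColoTensor(torch.randn(8))
x = torch.randn(2, 4)
# Dispatched to colo_linear, which unwraps the ColoTensors before the real kernel.
out = torch.nn.functional.linear(x, weight, bias=bias)
assert out.shape == (2, 8)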


@@ -1,11 +1,11 @@
 import torch
-from .api import _STATEFUL_OPS
+from .op_wrapper import _COLOSSAL_OPS


-class StatefulTensorV2(object):
+class ColoTensor(object):

     def __new__(cls, *args, **kwargs):
-        return super(StatefulTensorV2, cls).__new__(cls)
+        return super(ColoTensor, cls).__new__(cls)

     def __init__(self, t: torch.Tensor) -> None:
         self._torch_tensor = t
@@ -15,16 +15,15 @@ class StatefulTensorV2(object):
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
-        global _STATEFUL_OPS
-        if func in _STATEFUL_OPS:
-            # Find StatefulTensorV2 instance to get process_group.
+        global _COLOSSAL_OPS
+        if func in _COLOSSAL_OPS:
             for arg in args:
-                if isinstance(arg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(arg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

             for kwarg in kwargs.values():
-                if isinstance(kwarg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(kwarg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

         raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
-                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
+                           f"kwargs: {kwargs} not supported for ColoTensor!")
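To make the dispatch path above concrete, a sketch of the two outcomes: registered functions are routed to their handler, everything else raises (torch.svd is only assumed to be unregistered here):

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor(torch.randn(3, 5))
torch.mean(t)          # in _COLOSSAL_OPS -> forwarded to colo_mean

try:
    torch.svd(t)       # assumed not registered -> falls through to the RuntimeError
except RuntimeError as err:
    print(err)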


@@ -1,24 +1,39 @@
+from typing import (
+    Callable,
+    Dict,
+)
+
 import functools
-from .api import (
-    _register_stateful_op,)
+
+# Custom sharded ops
+_COLOSSAL_OPS: Dict[str, Callable] = {}


-def stateful_op_impl(func):
+def _register_colo_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 4:
+        raise TypeError(f'Custom stateful op function expects signature: '
+                        f'(types, args, kwargs, process_group), but received '
+                        f'signature: {signature(func)}')
+    global _COLOSSAL_OPS
+    _COLOSSAL_OPS[op] = func
+
+
+def colo_op_impl(func):
     """
     Provides a way for users to write their own custom operator. This
-    can be used to override existing StatefulTensorV2 operators or write a new
-    one not supported by StatefulTensorV2. If the operator in question is covered
-    by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
+    can be used to override existing ColoTensor operators or write a new
+    one not supported by ColoTensor. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
     parameters, the function provided will be invoked for that operator.

     Example::
-        >>> @stateful_op_impl(torch.nn.functional.linear)
+        >>> @colo_op_impl(torch.nn.functional.linear)
         >>> def my_custom_linear(types, args, kwargs, process_group):
         >>> ....
        >>>
         >>> input = torch.rand(10, 32)
-        >>> weight = StatefulTensorV2(torch.rand(32, 16))
-        >>> bias = StatefulTensorV2(torch.rand(16))
+        >>> weight = ColoTensor(torch.rand(32, 16))
+        >>> bias = ColoTensor(torch.rand(16))
         >>> # This will call `my_custom_linear` instead of the default.
         >>> torch.nn.functional.linear(input, weight, bias)
@@ -32,7 +47,7 @@ def stateful_op_impl(func):
     """

     def decorator_sharded_func(wrapped_func):
-        _register_stateful_op(func, wrapped_func)
+        _register_colo_op(func, wrapped_func)

         @functools.wraps(wrapped_func)
         def wrapper(*args, **kwargs):
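As a concrete (hypothetical) registration example beyond the docstring, the wrapped function must take exactly the four parameters checked by `_register_colo_op`:

import torch
from colossalai.tensor import ColoTensor, colo_op_impl

@colo_op_impl(torch.abs)
def colo_abs(types, args=(), kwargs=None, pg=None):
    # Unwrap the ColoTensor argument and fall back to the plain torch kernel.
    return torch.abs(args[0].torch_tensor())

t = ColoTensor(torch.randn(3))
print(torch.abs(t))    # now routed through __torch_function__ to colo_abs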


@@ -1,14 +1,10 @@
 import torch
-import torch.distributed as dist
-from torch.distributed import distributed_c10d
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.colo_tensor import ColoTensor


-def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
-    if not tensor.is_contiguous():
-        raise ValueError('input tensor is not a contiguous Tensor')
-    return StatefulTensorV2(tensor)
+def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
+    return ColoTensor(tensor)


 def convert_parameter(module: torch.nn.Module, param_name: str):
@@ -26,10 +22,10 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
     st = _convert_tensor(tensor)

-    # Replace param with StatefulTensorV2.
+    # Replace param with ColoTensor.
     # Need to delete the attribute first since param_name might be
-    # torch.nn.Parameter and can't be replaced with StatefulTensorV2 which is
+    # torch.nn.Parameter and can't be replaced with ColoTensor which is
     # not torch.nn.Parameter.
     delattr(module, param_name)
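A hedged sketch of the intended call pattern, assuming `convert_parameter` re-attaches the wrapped tensor under the same attribute name after the `delattr` above (the re-attachment itself lies outside the lines shown in this hunk):

import torch
from colossalai.tensor import ColoTensor, convert_parameter

fc = torch.nn.Linear(4, 8)
convert_parameter(fc, 'weight')     # replaces the nn.Parameter with a ColoTensor
print(type(fc.weight))              # expected: ColoTensor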


@@ -1,10 +1,6 @@
 from numpy import allclose
 import torch
-from torch import nn
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
-
-# TODO(jiaruifang) auto import
-from colossalai.gemini.tensor._ops import *
-from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from colossalai.tensor import ColoTensor
 from copy import deepcopy
@@ -18,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = StatefulTensorV2(fc_ref.weight)
-    sharded_bias = StatefulTensorV2(fc_ref.bias)
+    sharded_weight = ColoTensor(fc_ref.weight)
+    sharded_bias = ColoTensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
@@ -45,15 +41,14 @@ def test_linear():

 # The test case failed
 # def test_uniform():
-#     t = StatefulTensorV2(torch.zeros(3, 5))
-#     # print(_STATEFUL_OPS)
+#     t = ColoTensor(torch.zeros(3, 5))
 #     torch.nn.init.uniform_(t)
 #     print(t)


 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = StatefulTensorV2(t_ref.clone())
+    t = ColoTensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))