mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-06 06:02:16 +00:00)

[tensor] reorganize files (#820)

This commit is contained in:
parent ab962b9735
commit 0ce8924ceb
@@ -1,3 +0,0 @@
-from .init import stateful_uniform
-from .linear import stateful_linear
-from .element_wise import stateful_mean
@@ -1,17 +0,0 @@
-from typing import (
-    Callable,
-    Dict,
-)
-
-# Custom sharded ops
-_STATEFUL_OPS: Dict[str, Callable] = {}
-
-
-def _register_stateful_op(op, func):
-    from inspect import signature
-    if len(signature(func).parameters) != 4:
-        raise TypeError(f'Custom stateful op function expects signature: '
-                        f'(types, args, kwargs, process_group), but received '
-                        f'signature: {signature(func)}')
-    global _STATEFUL_OPS
-    _STATEFUL_OPS[op] = func
|
colossalai/tensor/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+from .op_wrapper import (
+    colo_op_impl,)
+from .colo_tensor import ColoTensor
+from .utils import convert_parameter
+from ._ops import *
+
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
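The new package exposes ColoTensor, convert_parameter and colo_op_impl directly from colossalai.tensor. As a quick orientation (not part of the diff), a minimal usage sketch against this new import surface might look like the following; it assumes the package is importable and that the op handlers pulled in via `from ._ops import *` are registered on import.

import torch
from colossalai.tensor import ColoTensor

# Wrap a plain torch.Tensor; ColoTensor holds it and routes torch functions
# through __torch_function__ (see colo_tensor.py further down in this diff).
t = ColoTensor(torch.randn(3, 5))
print(torch.mean(t))    # dispatched to colo_mean, which unwraps via torch_tensor()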
colossalai/tensor/_ops/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from .init import colo_uniform
+from .linear import colo_linear
+from .element_wise import colo_mean
@@ -1,17 +1,17 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
 
 
-@stateful_op_impl(torch.mean)
-def stateful_mean(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.mean)
+def colo_mean(types, args=(), kwargs=None, pg=None):
     stateful_tensor = args[0]
     return torch.mean(stateful_tensor.torch_tensor())
 
 
 def register_elementwise_op(op):
 
-    @stateful_op_impl(op)
+    @colo_op_impl(op)
     def elementwise_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for the elementwise op such

@@ -20,8 +20,8 @@ def register_elementwise_op(op):
         """
         input_tensor = args[0]
         # Validate types
-        if not isinstance(input_tensor, StatefulTensorV2):
-            raise TypeError("input needs to be a StatefulTensorV2")
+        if not isinstance(input_tensor, ColoTensor):
+            raise TypeError("input needs to be a ColoTensor")
         return op(input_tensor.torch_tensor())
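register_elementwise_op generates and registers a ColoTensor handler for any unary element-wise op. A hedged sketch of registering one more op this way is below; the module path colossalai.tensor._ops.element_wise is inferred from the _ops/__init__.py imports above, and torch.tanh is a hypothetical addition (only torch.mean is visible in this hunk, while the tests at the end of the diff also exercise gelu and relu).

import torch
# Assumed module path, derived from `from .element_wise import colo_mean` above.
from colossalai.tensor._ops.element_wise import register_elementwise_op

# Hypothetical: make torch.tanh dispatchable on ColoTensor inputs as well.
register_elementwise_op(torch.tanh)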
@@ -1,5 +1,5 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
+from colossalai.tensor.op_wrapper import colo_op_impl
 
 
 def validate_param(param, param_name):

@@ -7,8 +7,8 @@ def validate_param(param, param_name):
         raise ValueError(f"param: {param_name} shouldn't be None!")
 
 
-@stateful_op_impl(torch.nn.init.uniform_)
-def stateful_uniform(types, args=(), kwargs=None, pg=None):
+@colo_op_impl(torch.nn.init.uniform_)
+def colo_uniform(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
     distribution :math:`\mathcal{U}(a, b)`.
@@ -1,11 +1,11 @@
 import torch
-from colossalai.gemini.tensor import stateful_op_impl
-from ..stateful_tensor import StatefulTensorV2
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor.colo_tensor import ColoTensor
 from packaging import version
 
 
-@stateful_op_impl(torch.nn.functional.linear)
-def stateful_linear(types, args, kwargs, pg):
+@colo_op_impl(torch.nn.functional.linear)
+def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """

@@ -19,11 +19,11 @@ def stateful_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
-        if isinstance(bias, StatefulTensorV2):
+        if isinstance(bias, ColoTensor):
            bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
-    if isinstance(weight, StatefulTensorV2):
+    if isinstance(weight, ColoTensor):
         return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
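With torch.nn.functional.linear registered through colo_op_impl, passing a ColoTensor weight (or bias) triggers __torch_function__ and lands in colo_linear, which unwraps the ColoTensor before calling the regular kernel. A small usage sketch, assuming colossalai.tensor has been imported so the registration above has run; shapes follow the usual F.linear convention (weight is out_features x in_features).

import torch
from colossalai.tensor import ColoTensor

inp = torch.rand(4, 32)                    # plain tensor input
weight = ColoTensor(torch.rand(16, 32))    # wrapped weight triggers dispatch
bias = ColoTensor(torch.rand(16))

out = torch.nn.functional.linear(inp, weight, bias)   # routed to colo_linear
print(out.shape)                           # torch.Size([4, 16])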
@@ -1,11 +1,11 @@
 import torch
-from .api import _STATEFUL_OPS
+from .op_wrapper import _COLOSSAL_OPS
 
 
-class StatefulTensorV2(object):
+class ColoTensor(object):
 
     def __new__(cls, *args, **kwargs):
-        return super(StatefulTensorV2, cls).__new__(cls)
+        return super(ColoTensor, cls).__new__(cls)
 
     def __init__(self, t: torch.Tensor) -> None:
         self._torch_tensor = t

@@ -15,16 +15,15 @@ class StatefulTensorV2(object):
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
-        global _STATEFUL_OPS
-        if func in _STATEFUL_OPS:
-            # Find StatefulTensorV2 instance to get process_group.
+        global _COLOSSAL_OPS
+        if func in _COLOSSAL_OPS:
             for arg in args:
-                if isinstance(arg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(arg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
 
             for kwarg in kwargs.values():
-                if isinstance(kwarg, StatefulTensorV2):
-                    return _STATEFUL_OPS[func](types, args, kwargs, None)
+                if isinstance(kwarg, ColoTensor):
+                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
 
         raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
-                           f"kwargs: {kwargs} not supported for StatefulTensorV2!")
+                           f"kwargs: {kwargs} not supported for ColoTensor!")
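The dispatch above is the standard __torch_function__ pattern: a registry keyed by the torch function, consulted whenever a ColoTensor appears in args or kwargs. The following self-contained toy (not ColossalAI code; all names here are made up) shows the same mechanism in isolation.

import torch
from typing import Callable, Dict

_DEMO_OPS: Dict[Callable, Callable] = {}


class DemoTensor:
    """Toy stand-in for ColoTensor: wraps a tensor and intercepts torch functions."""

    def __init__(self, t: torch.Tensor) -> None:
        self._t = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _DEMO_OPS:
            return _DEMO_OPS[func](types, args, kwargs, None)
        raise RuntimeError(f"torch function '{func.__name__}' not supported for DemoTensor")


# Register a handler for torch.mean that unwraps the payload before computing.
_DEMO_OPS[torch.mean] = lambda types, args, kwargs, pg: torch.mean(args[0]._t)

print(torch.mean(DemoTensor(torch.ones(2, 2))))   # tensor(1.)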
@@ -1,24 +1,39 @@
+from typing import (
+    Callable,
+    Dict,
+)
 import functools
-from .api import (
-    _register_stateful_op,)
+
+# Custom sharded ops
+_COLOSSAL_OPS: Dict[str, Callable] = {}
 
 
-def stateful_op_impl(func):
+def _register_colo_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 4:
+        raise TypeError(f'Custom stateful op function expects signature: '
+                        f'(types, args, kwargs, process_group), but received '
+                        f'signature: {signature(func)}')
+    global _COLOSSAL_OPS
+    _COLOSSAL_OPS[op] = func
+
+
+def colo_op_impl(func):
     """
     Provides a way for users to write their own custom operator. This
-    can be used to override existing StatefulTensorV2 operators or write a new
-    one not supported by StatefulTensorV2. If the operator in question is covered
-    by ``__torch_function__`` dispatch and has a StatefulTensorV2 as any of its
+    can be used to override existing ColoTensor operators or write a new
+    one not supported by ColoTensor. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
     parameters, the function provided will be invoked for that operator.
 
     Example::
-        >>> @stateful_op_impl(torch.nn.functional.linear)
+        >>> @colo_op_impl(torch.nn.functional.linear)
         >>> def my_custom_linear(types, args, kwargs, process_group):
         >>> ....
         >>>
         >>> input = torch.rand(10, 32)
-        >>> weight = StatefulTensorV2(torch.rand(32, 16))
-        >>> bias = StatefulTensorV2(torch.rand(16))
+        >>> weight = ColoTensor(torch.rand(32, 16))
+        >>> bias = ColoTensor(torch.rand(16))
         >>> # This will call `my_custom_linear` instead of the default.
         >>> torch.nn.functional.linear(input, weight, bias)
 

@@ -32,7 +47,7 @@ def stateful_op_impl(func):
     """
 
     def decorator_sharded_func(wrapped_func):
-        _register_stateful_op(func, wrapped_func)
+        _register_colo_op(func, wrapped_func)
 
         @functools.wraps(wrapped_func)
         def wrapper(*args, **kwargs):
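_register_colo_op enforces the handler arity at decoration time: anything registered through colo_op_impl must accept (types, args, kwargs, process_group). A small sketch of what happens when that contract is violated; torch.abs is just a hypothetical target here and never actually gets registered.

import torch
from colossalai.tensor.op_wrapper import colo_op_impl

try:
    @colo_op_impl(torch.abs)        # hypothetical op for the demonstration
    def bad_handler(args):          # wrong arity: 1 parameter instead of 4
        ...
except TypeError as err:
    # Prints the message raised above, e.g.
    # "Custom stateful op function expects signature: (types, args, kwargs, process_group) ..."
    print(err)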
@@ -1,14 +1,10 @@
 import torch
-import torch.distributed as dist
-from torch.distributed import distributed_c10d
 
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
-def _convert_tensor(tensor: torch.Tensor) -> StatefulTensorV2:
-    if not tensor.is_contiguous():
-        raise ValueError('input tensor is not a contiguous Tensor')
-    return StatefulTensorV2(tensor)
+def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
+    return ColoTensor(tensor)
 
 
 def convert_parameter(module: torch.nn.Module, param_name: str):

@@ -26,10 +22,10 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
 
     st = _convert_tensor(tensor)
 
-    # Replace param with StatefulTensorV2.
+    # Replace param with ColoTensor.
 
     # Need to delete the attribute first since param_name might be
-    # torch.nn.Parameter and can't be replaced with StatefulTensorV2 which is
+    # torch.nn.Parameter and can't be replaced with ColoTensor which is
     # not torch.nn.Parameter.
     delattr(module, param_name)
 
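convert_parameter swaps a module's torch.nn.Parameter for a ColoTensor built by _convert_tensor; only part of its body is visible in this hunk, so the sketch below assumes it re-sets the converted value on the module after the delattr shown above.

import torch
from colossalai.tensor import ColoTensor, convert_parameter

fc = torch.nn.Linear(32, 16)
convert_parameter(fc, 'weight')           # replace the Parameter in place (assumed behaviour)
print(isinstance(fc.weight, ColoTensor))  # expected True if the attribute is re-set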
@@ -1,10 +1,6 @@
 from numpy import allclose
 import torch
-from torch import nn
-from colossalai.gemini.tensor.stateful_tensor import StatefulTensorV2
-# TODO(jiaruifang) auto import
-from colossalai.gemini.tensor._ops import *
-from colossalai.gemini.tensor.api import _STATEFUL_OPS
+from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
 

@@ -18,8 +14,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()
 
-    sharded_weight = StatefulTensorV2(fc_ref.weight)
-    sharded_bias = StatefulTensorV2(fc_ref.bias)
+    sharded_weight = ColoTensor(fc_ref.weight)
+    sharded_bias = ColoTensor(fc_ref.bias)
 
     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')

@@ -45,15 +41,14 @@ def test_linear():
 
 # The test case failed
 # def test_uniform():
-#     t = StatefulTensorV2(torch.zeros(3, 5))
-#     # print(_STATEFUL_OPS)
+#     t = ColoTensor(torch.zeros(3, 5))
 #     torch.nn.init.uniform_(t)
 #     print(t)
 
 
 def test_element_wise():
     t_ref = torch.randn(3, 5)
-    t = StatefulTensorV2(t_ref.clone())
+    t = ColoTensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
     assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
     assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))