[tensor] reorganize files (#820)
colossalai/tensor/__init__.py (new file)

from .op_wrapper import colo_op_impl
from .colo_tensor import ColoTensor
from .utils import convert_parameter
from ._ops import *

__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']

colossalai/tensor/_ops/__init__.py (new file)

from .init import colo_uniform
from .linear import colo_linear
from .element_wise import colo_mean

colossalai/tensor/_ops/element_wise.py (new file)

import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor


@colo_op_impl(torch.mean)
def colo_mean(types, args=(), kwargs=None, pg=None):
    """Handles ``__torch_function__`` dispatch for ``torch.mean`` on a ColoTensor."""
    stateful_tensor = args[0]
    return torch.mean(stateful_tensor.torch_tensor())


def register_elementwise_op(op):

    @colo_op_impl(op)
    def elementwise_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for an element-wise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        The input must be a ColoTensor; the op is applied to the wrapped
        ``torch.Tensor``.
        """
        input_tensor = args[0]
        # Validate types
        if not isinstance(input_tensor, ColoTensor):
            raise TypeError("input needs to be a ColoTensor")
        return op(input_tensor.torch_tensor())


register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)

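A minimal usage sketch, not part of the commit: it assumes the registry can be extended with any unary element-wise op (``torch.nn.functional.silu`` here is purely illustrative) and shows how a call on a ColoTensor is expected to reach the registered handler.

import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor._ops.element_wise import register_elementwise_op

# Illustrative extra registration; the commit itself only registers gelu and relu.
register_elementwise_op(torch.nn.functional.silu)

t = ColoTensor(torch.randn(4, 8))

# ColoTensor defines __torch_function__, so these calls are intercepted,
# looked up in _COLOSSAL_OPS, and applied to the wrapped torch.Tensor.
out = torch.nn.functional.relu(t)   # handled by the op registered for relu
mean = torch.mean(t)                # handled by colo_mean
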
colossalai/tensor/_ops/init.py (new file)

import torch
from colossalai.tensor.op_wrapper import colo_op_impl


def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")


@colo_op_impl(torch.nn.init.uniform_)
def colo_uniform(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the tensor wrapped by the ColoTensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: the ColoTensor to initialize
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    stateful_tensor = kwargs["tensor"]
    validate_param(stateful_tensor, "stateful_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")

    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
    return stateful_tensor

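A short sketch of the expected dispatch path, under the assumption that the installed PyTorch forwards ``torch.nn.init.uniform_`` through ``__torch_function__`` with ``tensor``, ``a`` and ``b`` as keyword arguments, which is the layout ``colo_uniform`` reads from ``kwargs``:

import torch
from colossalai.tensor import ColoTensor

weight = ColoTensor(torch.empty(32, 16))

# The __torch_function__ override on ColoTensor redirects this call to colo_uniform,
# which initializes the wrapped torch.Tensor in place and returns the ColoTensor.
torch.nn.init.uniform_(weight, a=-0.1, b=0.1)
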
colossalai/tensor/_ops/linear.py (new file)

import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor.colo_tensor import ColoTensor
from packaging import version


@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
    This method computes the linear transformation on the ``torch.Tensor`` wrapped
    by any ColoTensor arguments.
    """
    input_tensor = args[0]
    weight = args[1]

    # Newer torch passes bias positionally through __torch_function__;
    # older versions pass it as a keyword argument.
    if version.parse(torch.__version__) > version.parse("1.11.0"):
        if len(args) == 3:
            bias = args[2]
        else:
            bias = None
    else:
        bias = kwargs.get('bias', None)
    if isinstance(bias, ColoTensor):
        bias = bias.torch_tensor()

    # Add communication logic before and after linear call.
    if isinstance(weight, ColoTensor):
        return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
    else:
        return torch.nn.functional.linear(input_tensor, weight, bias)

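A minimal sketch, not part of the commit, of the intended call path: wrapping the weight (and optionally the bias) in a ColoTensor is enough for ``torch.nn.functional.linear`` to be routed to ``colo_linear``.

import torch
from colossalai.tensor import ColoTensor

x = torch.rand(10, 32)
weight = ColoTensor(torch.rand(16, 32))   # (out_features, in_features), as F.linear expects
bias = ColoTensor(torch.rand(16))

# The ColoTensor arguments trigger __torch_function__ dispatch; colo_linear unwraps
# them and falls back to the dense torch.nn.functional.linear.
y = torch.nn.functional.linear(x, weight, bias)   # plain torch.Tensor of shape (10, 16)
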
colossalai/tensor/colo_tensor.py (new file)

import torch
from .op_wrapper import _COLOSSAL_OPS


class ColoTensor(object):
    """A wrapper around ``torch.Tensor`` that routes supported torch functions to the
    handlers registered in ``_COLOSSAL_OPS`` through ``__torch_function__`` dispatch."""

    def __new__(cls, *args, **kwargs):
        return super(ColoTensor, cls).__new__(cls)

    def __init__(self, t: torch.Tensor) -> None:
        self._torch_tensor = t

    def torch_tensor(self) -> torch.Tensor:
        return self._torch_tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        global _COLOSSAL_OPS
        # kwargs may be None when the torch function is called without keyword arguments.
        if kwargs is None:
            kwargs = {}
        if func in _COLOSSAL_OPS:
            for arg in args:
                if isinstance(arg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

            for kwarg in kwargs.values():
                if isinstance(kwarg, ColoTensor):
                    return _COLOSSAL_OPS[func](types, args, kwargs, None)

        raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and "
                           f"kwargs: {kwargs} not supported for ColoTensor!")

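A short sketch, assuming behavior follows directly from the code above: a torch function with no entry in ``_COLOSSAL_OPS`` raises a RuntimeError when called on a ColoTensor instead of falling back to the default implementation.

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor(torch.randn(4, 4))

try:
    torch.sum(t)   # torch.sum has no registered handler in this commit
except RuntimeError as e:
    print(e)       # reports the unsupported function along with its args and kwargs
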
colossalai/tensor/op_wrapper.py (new file)

from typing import Callable, Dict
import functools

# Custom sharded ops, keyed by the torch function they override.
_COLOSSAL_OPS: Dict[Callable, Callable] = {}


def _register_colo_op(op, func):
    from inspect import signature
    if len(signature(func).parameters) != 4:
        raise TypeError(f'Custom stateful op function expects signature: '
                        f'(types, args, kwargs, process_group), but received '
                        f'signature: {signature(func)}')
    global _COLOSSAL_OPS
    _COLOSSAL_OPS[op] = func


def colo_op_impl(func):
    """
    Provides a way for users to write their own custom operator. This
    can be used to override existing ColoTensor operators or write a new
    one not supported by ColoTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example::
        >>> @colo_op_impl(torch.nn.functional.linear)
        >>> def my_custom_linear(types, args, kwargs, process_group):
        >>>     ...
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = ColoTensor(torch.rand(16, 32))
        >>> bias = ColoTensor(torch.rand(16))
        >>> # This will call `my_custom_linear` instead of the default.
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to the ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).

    Args:
        func(Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)
    """

    def decorator_sharded_func(wrapped_func):
        _register_colo_op(func, wrapped_func)

        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)

        return wrapper

    return decorator_sharded_func

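As a concrete sketch of the extension point described in the docstring (``colo_softmax`` is hypothetical and not part of this commit), a custom op only needs to accept the four-parameter signature that ``_register_colo_op`` enforces:

import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.op_wrapper import colo_op_impl


@colo_op_impl(torch.nn.functional.softmax)
def colo_softmax(types, args=(), kwargs=None, pg=None):
    # Hypothetical handler: unwrap the ColoTensor and fall back to the dense op.
    kwargs = kwargs or {}
    t = args[0]
    return torch.nn.functional.softmax(t.torch_tensor(), dim=kwargs.get('dim', -1))


probs = torch.nn.functional.softmax(ColoTensor(torch.randn(2, 5)), dim=-1)
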
colossalai/tensor/utils.py (new file)

import torch

from colossalai.tensor.colo_tensor import ColoTensor


def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
    return ColoTensor(tensor)


def convert_parameter(module: torch.nn.Module, param_name: str):
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _convert_tensor(tensor)

    # Replace param with ColoTensor.

    # Need to delete the attribute first, since the attribute may be a
    # torch.nn.Parameter, which can't be replaced with a ColoTensor because
    # ColoTensor is not a torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)

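A minimal end-to-end sketch of how ``convert_parameter`` is presumably meant to be used (my assumption, not shown in the commit): after conversion, the module's forward pass dispatches through the registered ops.

import torch
from colossalai.tensor import ColoTensor, convert_parameter

layer = torch.nn.Linear(32, 16)

# Replace the nn.Parameter 'weight' with a ColoTensor wrapping the same tensor.
convert_parameter(layer, 'weight')
assert isinstance(layer.weight, ColoTensor)

# nn.Linear.forward calls torch.nn.functional.linear(input, self.weight, self.bias);
# because self.weight is now a ColoTensor, the call is routed to colo_linear.
out = layer(torch.rand(10, 32))   # plain torch.Tensor of shape (10, 16)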