Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)

[tensor] customized op returns ColoTensor (#875)

* [tensor] customized op returns ColoTensor
* polish
* polish code
@@ -1,5 +1,4 @@
from .init import colo_uniform
from .linear import colo_linear
from .element_wise import colo_mean
from .element_wise import *
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy

@@ -29,3 +29,22 @@ def register_elementwise_op(op):

register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)


@colo_op_impl(torch.sum)
def sum_op(types, args=(), kwargs=None, pg=None):
    """
    Handles ``__torch_function__`` dispatch for ops such as ``torch.sum``.
    This method computes on either a normal tensor or a sharded tensor.
    """
    if len(args) > 0:
        input_tensor = args[0]
    if kwargs is None:
        kwargs = {}
    if 'input' in kwargs:
        input_tensor = kwargs['input']
    # Validate types
    if not isinstance(input_tensor, ColoTensor):
        raise TypeError("input needs to be a ColoTensor")
    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
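
Below is a minimal, self-contained sketch of the ``__torch_function__`` dispatch pattern that ``colo_op_impl`` appears to rely on: a wrapper type registers per-op handlers so that calling ``torch.sum`` on the wrapper hands a wrapper back, which is the point of this commit. ``MiniColoTensor``, ``op_impl`` and ``_OP_TABLE`` are illustrative stand-ins, not the actual ColossalAI API.

import torch

_OP_TABLE = {}  # maps a torch op to its custom handler


def op_impl(op):
    # Illustrative stand-in for colo_op_impl: record the handler for this op.
    def decorator(fn):
        _OP_TABLE[op] = fn
        return fn
    return decorator


class MiniColoTensor:
    def __init__(self, t: torch.Tensor):
        self._t = t

    def torch_tensor(self) -> torch.Tensor:
        return self._t

    @classmethod
    def init_from_torch_tensor(cls, t: torch.Tensor) -> "MiniColoTensor":
        return cls(t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Route registered ops to their handlers; otherwise unwrap and fall back.
        if func in _OP_TABLE:
            return _OP_TABLE[func](types, args=args, kwargs=kwargs)
        kwargs = kwargs or {}
        unwrapped = [a.torch_tensor() if isinstance(a, MiniColoTensor) else a for a in args]
        return func(*unwrapped, **kwargs)


@op_impl(torch.sum)
def sum_op_sketch(types, args=(), kwargs=None, pg=None):
    input_tensor = args[0]
    return MiniColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))


x = MiniColoTensor(torch.ones(4, 4))
print(type(torch.sum(x)).__name__)  # MiniColoTensor, not torch.Tensor
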
@@ -1,29 +0,0 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl


def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")


@colo_op_impl(torch.nn.init.uniform_)
def colo_uniform(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.
    Args:
        sharded_tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    stateful_tensor = kwargs["tensor"]
    validate_param(stateful_tensor, "stateful_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")

    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
    return stateful_tensor
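
The removed handler follows a simple kwargs-driven pattern: unwrap the wrapper, initialize the underlying tensor in place, and return the wrapper. A small self-contained sketch under that assumption follows; ``TensorWrapper`` and ``uniform_handler`` are illustrative stand-ins, and the handler is invoked directly rather than through ``__torch_function__`` dispatch.

import torch


class TensorWrapper:
    # Illustrative stand-in for ColoTensor holding a plain torch tensor.
    def __init__(self, t: torch.Tensor):
        self._t = t

    def torch_tensor(self) -> torch.Tensor:
        return self._t


def uniform_handler(types, args=(), kwargs=None, pg=None):
    kwargs = kwargs or {}
    wrapper = kwargs["tensor"]
    a, b = kwargs["a"], kwargs["b"]
    torch.nn.init.uniform_(wrapper.torch_tensor(), a=a, b=b)  # in-place init
    return wrapper


w = TensorWrapper(torch.empty(2, 3))
uniform_handler(None, kwargs={"tensor": w, "a": -0.1, "b": 0.1})
print(bool(w.torch_tensor().abs().max() <= 0.1))  # True
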
@@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.tensor import ComputePattern


@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):

@@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
    bias = kwargs.get('bias', None)

    if isinstance(bias, ColoTensor):
        assert bias.shard_spec.num_action == 0, "We currently only support a bias that is duplicated across processes in the linear operator"
        bias = bias.torch_tensor()

    # Add communication logic before and after linear call.

@@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
        input_tensor = input_tensor.torch_tensor()
        if isinstance(weight, ColoTensor):
            weight = weight.torch_tensor()
        return torch.nn.functional.linear(input_tensor, weight, bias)
        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
    elif weight.shard_spec.num_action == 1:
        if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
            # Input:S[1] x Weight:S[0] = Output:P

@@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
            output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
            # Bias
            if bias is not None:
                bias_ = bias
                output = output + bias_
                output = output + bias
            return ColoTensor.init_from_torch_tensor(output)
    else:
        raise NotImplementedError
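
The TP1DRow branch relies on the identity that a linear layer split along its in_features dimension produces partial outputs whose elementwise sum equals the full result; that sum is what ``reduce_input(partial_output, ParallelMode.PARALLEL_1D)`` appears to perform across ranks, with the bias added once afterwards. A single-process sketch of that identity follows; the shapes and the two-way split are illustrative, and a plain sum stands in for the all-reduce.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 6)   # (batch, in_features)
w = torch.randn(4, 6)   # (out_features, in_features), as nn.Linear stores it
b = torch.randn(4)

full = F.linear(x, w, b)

# Shard the input along its feature dim and the weight along its in_features dim.
x0, x1 = x.chunk(2, dim=1)
w0, w1 = w.chunk(2, dim=1)

# Each "rank" computes a partial output; no bias yet.
partial0 = F.linear(x0, w0)
partial1 = F.linear(x1, w1)

# The cross-rank reduce becomes a plain sum here; bias is applied once at the end.
output = partial0 + partial1 + b
print(torch.allclose(full, output, atol=1e-5))  # True
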
@@ -2,6 +2,7 @@ from enum import Enum
from typing import Tuple, List
from colossalai.context.parallel_mode import ParallelMode


class ComputePattern(Enum):
    TP1DRow = 1
    TP1DCol = 2

@@ -10,6 +11,7 @@ class ComputePattern(Enum):


class ParallelAction(object):

    def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
        self.priority = priority
        self.compute_pattern = compute_pattern

@@ -24,6 +26,7 @@ class TensorSpec(object):
    parallel computation pattern of the Operator (Layer).
    We have to consider the hybrid parallel mode.
    """

    # a list of parallel actions.
    # For example: On 8 GPUs, a hybrid parallel strategy is applied
    # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.

@@ -38,6 +41,7 @@ class TensorSpec(object):
    # Before Linear Op, we gather the tensors according to ZeRO.
    # We perform Linear Op according to compute pattern of TP1DRow.
    # After Linear Op, we split the tensors according to ZeRO.

    def __init__(self, parallel_action_list: List[ParallelAction] = []):
        self._parallel_action_list = parallel_action_list
        self.sort()

@@ -56,8 +60,8 @@ class TensorSpec(object):

    def sort(self):
        if len(self._parallel_action_list) > 0:
            self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)

    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
        for parallel_action in self._parallel_action_list:
            if parallel_action.compute_pattern == compute_pattern:
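
TensorSpec orders its ParallelActions by ascending priority, presumably so that they sort into execution order for a hybrid strategy (e.g. the ZeRO gather before the tensor-parallel compute, as the comments above describe). Below is a minimal, self-contained sketch of that interaction using simplified stand-ins rather than the colossalai classes; the extra enum members and the None default (replacing the mutable ``= []`` default shown above) are illustrative choices for the sketch.

from enum import Enum
from typing import List, Optional


class ComputePattern(Enum):
    TP1DRow = 1
    TP1DCol = 2
    ZeRO = 3   # illustrative member for the sketch
    DP = 4


class ParallelAction:
    def __init__(self, priority=0, compute_pattern=ComputePattern.DP) -> None:
        self.priority = priority
        self.compute_pattern = compute_pattern


class TensorSpec:
    def __init__(self, parallel_action_list: Optional[List[ParallelAction]] = None):
        self._parallel_action_list = parallel_action_list or []
        self.sort()

    def sort(self):
        # Ascending sort by priority; lower values are assumed to run first.
        self._parallel_action_list.sort(key=lambda a: a.priority)

    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
        for action in self._parallel_action_list:
            if action.compute_pattern == compute_pattern:
                return action
        return None


# Hybrid strategy from the comments above: ZeRO (data parallel) plus 1D row TP.
spec = TensorSpec([
    ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow),
    ParallelAction(priority=0, compute_pattern=ComputePattern.ZeRO),
])
print([a.compute_pattern.name for a in spec._parallel_action_list])       # ['ZeRO', 'TP1DRow']
print(spec.get_action_by_compute_pattern(ComputePattern.TP1DRow).priority)  # 1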