[tensor] refactor colo-tensor (#992)

* refactor colo-tensor and update linear op

* polish code

* polish code

* update ops and unit tests

* update unit tests

* polish code

* rename dist_spec module

* polish code

* polish code

* remove unneeded import

* fix pipelinable
ver217
2022-05-19 12:44:59 +08:00
committed by GitHub
parent 1467d83edf
commit ad536e308e
27 changed files with 657 additions and 616 deletions

View File

@@ -6,10 +6,10 @@ from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
from ._ops import *
from .optim.colo_optimizer import ColoOptimizer
from . import dist_spec
from . import distspec
from .dist_spec_mgr import DistSpecManager
__all__ = [
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'dist_spec', 'DistSpecManager'
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager'
]

View File

@@ -0,0 +1,12 @@
import torch
from typing import Union, Optional
from colossalai.tensor import ColoTensor
GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
if tensor is not None and not isinstance(tensor, ColoTensor):
tensor = ColoTensor.from_torch_tensor(tensor)
return tensor
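For reference, a minimal usage sketch of the new helper. The import path is an assumption inferred from the relative `._utils` imports used in the op files below; the behaviour follows directly from the function body above.

```python
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor._ops._utils import convert_to_colo_tensor  # assumed module path

x = torch.rand(4, 4)
t = convert_to_colo_tensor(x)                # plain tensors get wrapped
assert isinstance(t, ColoTensor)
assert convert_to_colo_tensor(t) is t        # ColoTensors pass through unchanged
assert convert_to_colo_tensor(None) is None  # optional args (e.g. a missing bias) stay None
```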

View File

@@ -1,64 +1,66 @@
import torch
from typing import Union
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
from colossalai.tensor import dist_spec
from colossalai.tensor import distspec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
mat1.to_dist_spec(dist_spec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group().size()]))
mat1 = mat1.convert_to_dist_spec(
distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
# Output:P
partial_output = torch.mm(mat1.torch_tensor(), mat2.torch_tensor())
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
# input
assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor.torch_tensor() + alpha * output
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(mat2.spec.get_process_group())))
output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
return output
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
mat1 = reduce_grad(mat1, parallel_action.parallel_mode)
output_parallel = torch.addmm(input_tensor.torch_tensor(),
mat1_torch_tensor,
mat2.torch_tensor(),
beta=beta,
alpha=alpha)
output_spec = TensorSpec(
dist_spec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group().size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if parallel_action.gather_out:
# All-Gather(Output)
output.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
output = output.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
return output
def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
return funcs[mode](input_tensor, mat1, mat2, beta, alpha)
@colo_op_impl(torch.addmm)
def colo_addmm(types, args, kwargs, pg):
def colo_addmm(input_tensor: GeneralTensor,
mat1: GeneralTensor,
mat2: GeneralTensor,
*args,
beta: Number = 1,
alpha: Number = 1) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
input_tensor, mat1, mat2 = args[:3]
to_colo_tensor = lambda t: t if isinstance(t, ColoTensor) else ColoTensor.init_from_torch_tensor(t)
input_tensor = to_colo_tensor(input_tensor)
mat2 = to_colo_tensor(mat2)
beta = kwargs.get('beta', 1) if kwargs else 1
alpha = kwargs.get('alpha', 1) if kwargs else 1
input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))
# building the computing graph, inputs -> op
# if GraphGlobalEnv().graph_building:
@@ -70,17 +72,15 @@ def colo_addmm(types, args, kwargs, pg):
if not mat2.has_spec(): # No Model Parallel Applied
assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.init_from_torch_tensor(
torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
mode = 'row'
elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
mode = 'col'
else:
raise NotImplementedError
ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
else:
raise NotImplementedError
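The 1D-row addmm math can be checked numerically without a launched job. Below is a single-process sketch, assuming a tensor-parallel world size of 2, where a plain sum stands in for the all-reduce performed by `reduce_input`.

```python
import torch

torch.manual_seed(0)
input_tensor = torch.rand(4, 6)
mat1 = torch.rand(4, 8)
mat2 = torch.rand(8, 6)
beta, alpha = 1.0, 1.0
world_size = 2  # assumed tensor-parallel degree

# mat1:S[1] x mat2:S[0] -> each "rank" holds a partial product of the full mm
mat1_shards = mat1.chunk(world_size, dim=-1)
mat2_shards = mat2.chunk(world_size, dim=0)
partials = [torch.mm(a, b) for a, b in zip(mat1_shards, mat2_shards)]

reduced = torch.stack(partials).sum(dim=0)   # stands in for the all-reduce
out = beta * input_tensor + alpha * reduced
assert torch.allclose(out, torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)
```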

View File

@@ -1,64 +1,28 @@
from copy import copy
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
@colo_op_impl(torch.allclose)
def colo_mean(types, args=(), kwargs=None, pg=None):
a = args[0]
b = args[1]
if isinstance(a, ColoTensor):
a = a.torch_tensor()
elif isinstance(b, ColoTensor):
b = b.torch_tensor()
if kwargs is None:
kwargs = {}
return torch.allclose(a, b, **kwargs)
@colo_op_impl(torch.mean)
def colo_mean(types, args=(), kwargs=None, pg=None):
input_t = args[0]
if isinstance(input_t, ColoTensor):
input_t = input_t.torch_tensor()
return ColoTensor.init_from_torch_tensor(torch.mean(input_t))
from ._utils import GeneralTensor
def register_elementwise_op(op):
@colo_op_impl(op)
def elementwise_op(types, args=(), kwargs=None, pg=None):
def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
input_tensor = args[0]
# Validate types
if not isinstance(input_tensor, ColoTensor):
raise TypeError("input needs to be a ColoTensor")
return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))
output = op(input_tensor, *args, **kwargs)
if isinstance(input_tensor, ColoTensor):
spec = copy(input_tensor.spec)
return ColoTensor.from_torch_tensor(output, spec=spec)
return ColoTensor.from_torch_tensor(output)
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)
@colo_op_impl(torch.sum)
def sum_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.sum``.
This method computes on either a normal tensor or a sharded tensor.
"""
if len(args) > 0:
input_tensor = args[0]
if kwargs is None:
kwargs = {}
if 'input' in kwargs:
input_tensor = kwargs['input']
# Validate types
if not isinstance(input_tensor, ColoTensor):
raise TypeError("input needs to be a ColoTensor")
return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
register_elementwise_op(torch.clone)
register_elementwise_op(torch.Tensor.clone)
register_elementwise_op(torch.Tensor.detach)
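A small usage sketch of the elementwise wrapper, assuming the default replicated spec needs no initialized process group: `clone` and `detach` now go through `register_elementwise_op`, and the output inherits a copy of the input's spec.

```python
import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.from_torch_tensor(torch.rand(2, 3))
c = t.clone()                                   # dispatched via register_elementwise_op
assert isinstance(c, ColoTensor)
# the wrapper copies the input spec onto the output, so the layout is preserved
assert c.spec.get_placement() == t.spec.get_placement()
```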

View File

@@ -1,31 +1,52 @@
import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
def colo_embedding_1Dcol(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
output_parallel = F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = TensorSpec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
def colo_embedding_1Drow(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Drow splits the weight (lookup table) into (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
num_embeddings_per_partition = weight.size(0)
@@ -33,53 +54,87 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
vocab_end_index = vocab_start_index + num_embeddings_per_partition
# Build the mask.
input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
(input_tensor.torch_tensor() >= vocab_end_index)
input_mask = (input_tensor < vocab_start_index) | \
(input_tensor >= vocab_end_index)
# Mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
masked_input = input_tensor.clone() - vocab_start_index
masked_input[input_mask] = 0
partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(), *args, **kwargs)
partial_output = F.embedding(masked_input,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
return output
@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
def colo_embedding_1d(mode: str,
input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
return funcs[mode](input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
weight: GeneralTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
This method looks up an embedding table.
"""
input_tensor = args[0]
weight = args[1]
args = args[2:]
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if not isinstance(weight, ColoTensor):
weight = ColoTensor.init_from_torch_tensor(weight)
input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
# Handle different parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
return ColoTensor.init_from_torch_tensor(output)
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_row():
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
mode = 'row'
elif weight.spec.is_1D_col():
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
mode = 'col'
else:
raise NotImplementedError
return colo_embedding_1d(mode,
input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
else:
raise NotImplementedError
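The 1D-row embedding scheme can also be verified in a single process. A sketch assuming 2 tensor-parallel ranks: each rank owns a vocab slice, masks indices outside its slice, looks them up locally, zeroes the masked rows, and the final all-reduce (simulated with a sum here) recovers the full lookup.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, embedding_dim, world_size = 8, 5, 2
weight = torch.rand(num_embeddings, embedding_dim)
idx = torch.tensor([[0, 3, 5], [7, 2, 6]])

per_rank = num_embeddings // world_size
partials = []
for rank in range(world_size):
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    mask = (idx < lo) | (idx >= hi)      # indices this rank does not own
    local_idx = idx.clone() - lo
    local_idx[mask] = 0
    part = F.embedding(local_idx, weight[lo:hi])
    part[mask, :] = 0.                   # zero the rows looked up with dummy indices
    partials.append(part)

output = torch.stack(partials).sum(dim=0)  # stands in for reduce_input
assert torch.allclose(output, F.embedding(idx, weight))
```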

View File

@@ -1,39 +1,24 @@
import torch
import torch.nn.functional as F
from typing import List, Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, dist_spec
from colossalai.tensor import ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(torch.nn.functional.layer_norm)
def colo_layernorm(types, args=(), kwargs=None, pg=None):
arg_num = len(args)
if arg_num > 0:
input_tensor = args[0]
if arg_num > 1:
normalized_shape = args[1]
if arg_num > 2:
weight = args[3]
if arg_num > 3:
bias = args[4]
if arg_num > 4:
eps = args[5]
@colo_op_impl(F.layer_norm)
def colo_layernorm(
input_tensor: GeneralTensor,
normalized_shape: List[int],
weight: Optional[GeneralTensor] = None,
bias: Optional[GeneralTensor] = None,
eps: float = 1e-5,
):
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
if 'input' in kwargs:
input_tensor = kwargs['input']
if 'weight' in kwargs:
weight = kwargs['weight']
if 'bias' in kwargs:
bias = kwargs['bias']
if 'eps' in kwargs:
eps = kwargs['eps']
# TODO (ver217): check dist spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
if isinstance(input_tensor, ColoTensor):
# TODO (ver217): check input dist spec
input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
input_tensor = input_tensor.torch_tensor()
if isinstance(weight, ColoTensor):
weight = weight.torch_tensor()
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
return ColoTensor.init_from_torch_tensor(
torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
return output

View File

@@ -1,108 +1,89 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
input_tensor.to_dist_spec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]))
input_tensor = input_tensor.convert_to_dist_spec(
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
# Output:P
partial_output = torch.nn.functional.linear(input_tensor.torch_tensor(), weight.torch_tensor())
partial_output = F.linear(input_tensor, weight)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
# Bias
if bias is not None:
assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias.torch_tensor()
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
return output
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
if bias is not None:
bias = bias.torch_tensor()
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(
output_parallel,
spec=TensorSpec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
if parallel_action.gather_out:
# All-Gather(Output)
output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
return funcs[mode](input_tensor, weight, bias)
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear transformation.
"""
input_tensor = args[0]
weight = args[1]
if version.parse(torch.__version__) > version.parse("1.11.0"):
if len(args) == 3:
bias = args[2]
else:
bias = None
else:
bias = kwargs.get('bias', None)
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if not isinstance(weight, ColoTensor):
weight = ColoTensor.init_from_torch_tensor(weight)
if bias is not None and not isinstance(bias, ColoTensor):
bias = ColoTensor.init_from_torch_tensor(bias)
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# building the computing graph, inputs -> op
if GraphGlobalEnv().graph_building:
cur_op_node = GraphOpNode('linear', [weight, bias])
cur_op_node.add_prev_tensor(input_tensor)
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_spec(): # No Model Parallel Applied
assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
if bias is not None:
bias = bias.torch_tensor()
ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
mode = 'row'
elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
mode = 'col'
else:
raise NotImplementedError
ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
else:
raise NotImplementedError
# building the computing graph, op -> output
if GraphGlobalEnv().graph_building:
cur_op_node.add_post_tensor(ret_tensor)
return ret_tensor
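The 1D-col linear path (weight and bias split along the output-feature dimension) can be checked the same way. A single-process sketch assuming 2 ranks, with `torch.cat` standing in for the all-gather performed when `parallel_action.gather_out` is set.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 16)
weight = torch.rand(10, 16)    # (out_features, in_features)
bias = torch.rand(10)
world_size = 2                 # assumed tensor-parallel degree

w_shards = weight.chunk(world_size, dim=0)
b_shards = bias.chunk(world_size, dim=0)
slices = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]

gathered = torch.cat(slices, dim=-1)   # stands in for the all-gather
assert torch.allclose(gathered, F.linear(x, weight, bias), atol=1e-6)
```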

View File

@@ -1,40 +1,37 @@
from colossalai.tensor.dist_spec import DistPlacementPattern
import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
arg_num = len(args)
if arg_num > 0:
input_tensor = args[0]
if arg_num > 1:
target = args[1]
if arg_num > 2:
weight = args[2]
if 'input' in kwargs:
input_tensor = kwargs.pop('input')
if 'target' in kwargs:
target = kwargs.pop('target')
if 'weight' in kwargs:
weight = kwargs.pop('weight')
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if isinstance(target, ColoTensor):
target = target.torch_tensor()
@colo_op_impl(F.cross_entropy)
def colo_cross_entropy(input_tensor: GeneralTensor,
target: GeneralTensor,
weight: Optional[GeneralTensor] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0):
input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
if input_tensor.spec.is_gathered(): # Input is gathered
return ColoTensor.init_from_torch_tensor(
torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
output = F.cross_entropy(input_tensor,
target,
weight=weight,
size_average=size_average,
ignore_index=ignore_index,
reduce=reduce,
reduction=reduction,
label_smoothing=label_smoothing)
return ColoTensor.from_torch_tensor(output)
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
if input_tensor.spec.is_1D_col():
return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
target))
output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
return ColoTensor.from_torch_tensor(output)
else:
raise NotImplementedError
else:
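A usage sketch of the gathered (non-parallel) branch, assuming the logits carry only the default replicated spec so the `VocabParallelCrossEntropyLoss1D` branch is not taken; the op then simply wraps `F.cross_entropy` and returns a ColoTensor.

```python
import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor

logits = ColoTensor.from_torch_tensor(torch.randn(8, 10))
target = torch.randint(0, 10, (8,))
loss = F.cross_entropy(logits, target)   # dispatched to colo_cross_entropy
assert isinstance(loss, ColoTensor) and loss.dim() == 0
```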

View File

@@ -1,6 +1,8 @@
from .colo_tensor import ColoTensor
from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
class ColoParameter(ColoTensor):
@@ -8,21 +10,26 @@ class ColoParameter(ColoTensor):
"""
def __init__(self, *args, **kargs):
super().__init__(*args, **kargs)
self._type = TensorType.MODEL
def __new__(cls,
data: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __new__(cls, *args, **kwargs):
t = super(ColoParameter, cls).__new__(cls)
t._type = TensorType.MODEL
return t
def __init__(self,
data: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec)
self._type = TensorType.MODEL
self._graph_node = None
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
colo_p = ColoParameter(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0))
return colo_p
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor
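A minimal construction sketch for the new `ColoParameter.from_torch_tensor` path, assuming the default replicated spec is valid outside a launched distributed job.

```python
import torch
from colossalai.tensor import ColoParameter, TensorSpec, distspec

w = torch.randn(16, 16)
p = ColoParameter.from_torch_tensor(w, spec=TensorSpec(distspec.replicate()))
assert isinstance(p, ColoParameter)
print(p.spec.get_placement())   # replicated by default
```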

View File

@@ -1,16 +1,23 @@
from .op_wrapper import _COLOSSAL_OPS
from copy import copy
import torch
from typing import Tuple, Optional, Callable, Union
from numpy import product
from colossalai.tensor import TensorSpec
from .const import TensorType
from colossalai.tensor import dist_spec
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.dist_spec import _DistSpec
from colossalai.tensor.distspec import _DistSpec
from torch.overrides import get_default_nowrap_functions
class ColoTensor(object):
def _convert_output(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output = ColoTensor.from_torch_tensor(output)
elif isinstance(output, (list, tuple)):
output = type(output)(_convert_output(o) for o in output)
return output
class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI
1. It contains a torch.Tensor as an attribute.
2. It supports lazy init of the tensor's payload.
@@ -18,120 +25,23 @@ class ColoTensor(object):
4. It supports distributing the tensor's payload to the shards among processes. (TODO)
"""
def __new__(cls, *args, **kwargs):
return super(ColoTensor, cls).__new__(cls)
def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
def __init__(self,
*size: Tuple[int],
dtype=None,
requires_grad=False,
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
spec: TensorSpec = TensorSpec(dist_spec.replicate())):
self._size = size
self._dtype = dtype
self._requires_grad = requires_grad
self._pin_memory = pin_memory
self._device = device
self._torch_tensor = torch_tensor
def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec)
self._type = TensorType.NONMODEL
self._graph_node = None
def __getitem__(self, key):
return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
@property
def spec(self) -> TensorSpec:
return self._spec
@property
def shard_pattern(self):
return self._shard_pattern
@property
def data(self):
return self._torch_tensor.data
@data.setter
def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
if isinstance(tensor, ColoTensor):
self._torch_tensor.data = tensor.data
elif isinstance(tensor, torch.Tensor):
self._torch_tensor.data = tensor
else:
raise NotImplementedError
@property
def grad(self):
return self._torch_tensor.grad
@property
def size(self):
return self._size
@property
def shape(self):
return torch.Size(self._size)
@property
def device(self):
return self._torch_tensor.device
def size(self, dim=None):
if dim is None:
return self.shape
return self._size[dim]
def dim(self):
return len(self._size)
def normal_(self, mean=0., std=1.):
torch_tensor = self.torch_tensor()
return torch_tensor.normal_(mean=mean, std=std)
def numel(self):
return product(self._size)
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor,
save_payload=True,
spec: TensorSpec = TensorSpec(dist_spec.replicate())) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0),
spec=spec)
return colo_t
def del_torch_tensor(self, save_shape=False) -> None:
"""
delete the payload of the torch tensor.
Args:
save_shape (bool, optional): if saving the shape of the torch_tensor.
If saving the shape, the size of self._torch_tensor is inconsistent with self._size.
Defaults to False.
"""
if not save_shape:
self._size = (0,)
self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
def torch_tensor(self) -> torch.Tensor:
if self._torch_tensor.numel() == 0:
self._torch_tensor = torch.empty(*self._size,
dtype=self._dtype,
pin_memory=self._pin_memory,
requires_grad=self._requires_grad,
device=self._device)
return self._torch_tensor
def set_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self.to_dist_spec(spec.dist_spec)
self.convert_to_dist_spec_(spec.dist_spec)
self._spec = spec
def has_spec(self) -> bool:
@@ -142,89 +52,51 @@ class ColoTensor(object):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if not all(issubclass(cls, t) for t in types):
return NotImplemented
global _COLOSSAL_OPS
if func in _COLOSSAL_OPS:
for arg in args:
if isinstance(arg, ColoTensor):
return _COLOSSAL_OPS[func](types, args, kwargs, None)
func = _COLOSSAL_OPS[func]
for kwarg in kwargs.values():
if isinstance(kwarg, ColoTensor):
return _COLOSSAL_OPS[func](types, args, kwargs, None)
else:
# If we have not hijacked the function, convert the ColoTensors in args and kwargs to torch tensors.
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
if kwargs is None:
kwargs = {}
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
else:
return _convert_output(ret)
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return cls._filter_outputs_with_colo(func(*args, **kwargs))
def __repr__(self):
return f'ColoTensor: {super().__repr__()}'
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def __add__(self, o) -> "ColoTensor":
if isinstance(o, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
elif isinstance(o, (torch.Tensor, int, float)):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
else:
raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
__radd__ = __add__
def __truediv__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
def __getattr__(self, name):
def replace_tensor_with_colo(func):
def execute_func(*args, **kwargs):
# transform the ColoTensor args to torch Tensors.
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
if kwargs is None:
kwargs = {}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return self._filter_outputs_with_colo(func(*args, **kwargs))
return execute_func
if hasattr(self._torch_tensor, name) == False:
raise AttributeError
attr = getattr(self._torch_tensor, name)
if isinstance(attr, Callable):
return replace_tensor_with_colo(attr)
else:
return attr
@classmethod
def _filter_outputs_with_colo(cls, outputs):
if outputs is None: # return None
return None
elif type(outputs) is not tuple: # num of return val = 1
return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
else: # num of return val > 1
return tuple([
ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
for output in outputs
])
def __mul__(self, other) -> "ColoTensor":
if isinstance(other, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
elif isinstance(other, (torch.Tensor, int, float)):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other)
else:
raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')
__rmul__ = __mul__
def to_dist_spec(self, dist_spec: _DistSpec) -> None:
self._torch_tensor = DistSpecManager.handle_trans_spec(self.torch_tensor(), self.spec.dist_spec, dist_spec)
if self._torch_tensor.is_leaf:
self._torch_tensor.requires_grad = self._requires_grad
self._size = self._torch_tensor.size()
def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
self._spec.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
spec = copy(self._spec)
spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, spec)
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.spec))
memo[id(self)] = tensor
return tensor
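The core of the refactor is that ColoTensor is now a real `torch.Tensor` subclass built with `as_subclass`, with `__torch_function__` re-wrapping outputs. The same pattern can be demonstrated standalone in plain PyTorch (no colossalai required; the class and names below are illustrative only).

```python
import torch

class WrappedTensor(torch.Tensor):

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor) -> 'WrappedTensor':
        return tensor.as_subclass(WrappedTensor)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        # re-wrap plain tensor outputs so the subclass propagates through torch ops
        if isinstance(ret, torch.Tensor) and not isinstance(ret, WrappedTensor):
            ret = ret.as_subclass(WrappedTensor)
        return ret

x = WrappedTensor.from_torch_tensor(torch.rand(2, 2))
y = x @ x + 1
assert isinstance(y, WrappedTensor)
```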

View File

@@ -1,4 +1,4 @@
from colossalai.tensor.dist_spec import _DistSpec
from colossalai.tensor.distspec import _DistSpec
from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager
@@ -53,7 +53,7 @@ class DistSpecManager:
@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
and dist_spec.process_group is not None:
raise NotImplementedError
return tensor
@@ -66,7 +66,7 @@ class DistSpecManager:
@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
if old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
and dist_spec.process_group is not None:
raise NotImplementedError
return DistSpecManager._gather(tensor, old_dist_spec)

View File

@@ -9,11 +9,6 @@ _COLOSSAL_OPS: Dict[str, Callable] = {}
def _register_colo_op(op, func):
from inspect import signature
if len(signature(func).parameters) != 4:
raise TypeError(f'Custom stateful op function expects signature: '
f'(types, args, kwargs, process_group), but received '
f'signature: {signature(func)}')
global _COLOSSAL_OPS
_COLOSSAL_OPS[op] = func
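With the fixed `(types, args, kwargs, process_group)` signature check removed, registered handlers can mirror the signature of the op they wrap. A hypothetical registration sketch (`torch.flatten` is only an illustrative choice, not part of this commit):

```python
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor.op_wrapper import colo_op_impl

@colo_op_impl(torch.flatten)
def colo_flatten(input_tensor: ColoTensor, start_dim: int = 0, end_dim: int = -1) -> ColoTensor:
    # the handler now takes the same arguments as the wrapped op
    with torch._C.DisableTorchFunction():   # operate on the raw payload
        output = torch.flatten(input_tensor, start_dim, end_dim)
    return ColoTensor.from_torch_tensor(output)
```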

View File

@@ -1,7 +1,8 @@
import torch.distributed as dist
from enum import Enum
from typing import List
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
class ComputePattern(Enum):
@@ -77,6 +78,9 @@ class TensorSpec(object):
def get_process_group(self):
return self.dist_spec.process_group
def get_process_group_size(self):
return dist.get_world_size(self.dist_spec.process_group)
def get_placement(self):
return self.dist_spec.placement
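A short sketch of the new accessor, assuming `torch.distributed` has already been initialized (e.g. via `colossalai.launch`):

```python
import torch.distributed as dist
from colossalai.tensor import TensorSpec, distspec

spec = TensorSpec(distspec.replicate(dist.group.WORLD))
# get_process_group_size simply defers to torch.distributed for the group's world size
assert spec.get_process_group_size() == dist.get_world_size(dist.group.WORLD)
```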