Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[tensor] refactor colo-tensor (#992)
* refactor colo-tensor and update linear op
* polish code
* polish code
* update ops and unit tests
* update unit tests
* polish code
* rename dist_spec module
* polish code
* polish code
* remove unneeded import
* fix pipelinable
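The diff below touches many files, so here is a brief summary of the API changes it makes, with a small hedged Python sketch (the tensors and the commented-out process group `pg` are placeholders, not part of the commit): the `dist_spec` module is renamed to `distspec`, `ColoTensor` becomes a real `torch.Tensor` subclass, `init_from_torch_tensor` is replaced by `from_torch_tensor`, the in-place `to_dist_spec` is replaced by `convert_to_dist_spec` / `convert_to_dist_spec_`, the manual `.torch_tensor()` unwrapping disappears, and op handlers get real Python signatures instead of `(types, args, kwargs, pg)`.

# A minimal before/after sketch of the API changes in this commit, inferred
# from the diff below; `t` and the commented-out `pg` are placeholders.
import torch
from colossalai.tensor import ColoTensor, TensorSpec, distspec    # was: dist_spec

t = torch.randn(4, 4)

# old: colo_t = ColoTensor.init_from_torch_tensor(t)
colo_t = ColoTensor.from_torch_tensor(t, spec=TensorSpec(distspec.replicate()))

# old: colo_t.to_dist_spec(dist_spec.replicate(pg))      # mutated in place
# new: colo_t = colo_t.convert_to_dist_spec(distspec.replicate(pg))

# ColoTensor is now a torch.Tensor subclass, so ops take it directly;
# the old `.torch_tensor()` unwrapping is gone.
out = torch.mm(colo_t, colo_t)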
@@ -6,10 +6,10 @@ from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from ._ops import *
 from .optim.colo_optimizer import ColoOptimizer
-from . import dist_spec
+from . import distspec
 from .dist_spec_mgr import DistSpecManager

 __all__ = [
     'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
-    'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'dist_spec', 'DistSpecManager'
+    'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager'
 ]
colossalai/tensor/_ops/_utils.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+import torch
+from typing import Union, Optional
+from colossalai.tensor import ColoTensor
+
+GeneralTensor = Union[ColoTensor, torch.Tensor]
+Number = Union[int, float]
+
+
+def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
+    if tensor is not None and not isinstance(tensor, ColoTensor):
+        tensor = ColoTensor.from_torch_tensor(tensor)
+    return tensor
@@ -1,64 +1,66 @@
 import torch
-from typing import Union
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
-from colossalai.tensor import dist_spec
+from colossalai.tensor import distspec
+from ._utils import GeneralTensor, Number, convert_to_colo_tensor


-def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
-                     alpha: Union[int, float]) -> ColoTensor:
+def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
+                     alpha: Number) -> ColoTensor:
     parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res

-    mat1.to_dist_spec(dist_spec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group().size()]))
+    mat1 = mat1.convert_to_dist_spec(
+        distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))

     # Output:P
-    partial_output = torch.mm(mat1.torch_tensor(), mat2.torch_tensor())
+    partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
     output = reduce_input(partial_output, parallel_action.parallel_mode)
     # input
     assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
-    output = beta * input_tensor.torch_tensor() + alpha * output
-    output = ColoTensor.init_from_torch_tensor(output,
-                                               spec=TensorSpec(dist_spec.replicate(mat2.spec.get_process_group())))
+    output = beta * input_tensor + alpha * output
+    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
     return output


-def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
-                     alpha: Union[int, float]) -> ColoTensor:
+def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
+                     alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
-    mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
-    mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
+    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
+    mat1 = reduce_grad(mat1, parallel_action.parallel_mode)

-    output_parallel = torch.addmm(input_tensor.torch_tensor(),
-                                  mat1_torch_tensor,
-                                  mat2.torch_tensor(),
-                                  beta=beta,
-                                  alpha=alpha)
-    output_spec = TensorSpec(
-        dist_spec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group().size()]),
-        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
-    output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
+    output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
+    output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
+                             [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     if parallel_action.gather_out:
         # All-Gather(Output)
-        output.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
+        output = output.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
     return output


+def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
+                  alpha: Number) -> ColoTensor:
+    assert mode in ('row', 'col')
+    funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
+    return funcs[mode](input_tensor, mat1, mat2, beta, alpha)
+
+
 @colo_op_impl(torch.addmm)
-def colo_addmm(types, args, kwargs, pg):
+def colo_addmm(input_tensor: GeneralTensor,
+               mat1: GeneralTensor,
+               mat2: GeneralTensor,
+               *args,
+               beta: Number = 1,
+               alpha: Number = 1) -> ColoTensor:
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
-    input_tensor, mat1, mat2 = args[:3]
-    to_colo_tensor = lambda t: t if isinstance(t, ColoTensor) else ColoTensor.init_from_torch_tensor(t)
-    input_tensor = to_colo_tensor(input_tensor)
-    mat2 = to_colo_tensor(mat2)
-    beta = kwargs.get('beta', 1) if kwargs else 1
-    alpha = kwargs.get('alpha', 1) if kwargs else 1
+    input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))

     # building the computing graph, inputs -> op
     # if GraphGlobalEnv().graph_building:
@@ -70,17 +72,15 @@ def colo_addmm(types, args, kwargs, pg):
     if not mat2.has_spec():    # No Model Parallel Applied
         assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
         assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
-        ret_tensor = ColoTensor.init_from_torch_tensor(
-            torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
+        ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
-        mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
         if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
-            ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
+            mode = 'row'
         elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
-            ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
+            mode = 'col'
         else:
            raise NotImplementedError
+        ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
     else:
         raise NotImplementedError

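The comments in colo_addmm_1Drow above describe the tensor-parallel math: mat1:S[1] x mat2:S[0] gives a partial product on every rank, a single all-reduce restores the full mat1 @ mat2, and beta * input + alpha * result finishes the addmm. Below is a forward-only, hedged sketch of that idea written with plain torch.distributed; the real handler uses ColossalAI's reduce_input and ProcessGroup helpers so the backward pass is also handled.

# Hedged sketch of the 1D-row addmm math from the handler above, assuming
# torch.distributed is already initialized and mat1 is column-sharded while
# mat2 is row-sharded consistently across ranks.
import torch
import torch.distributed as dist

def addmm_1d_row(input_tensor, mat1_shard, mat2_shard, beta=1.0, alpha=1.0):
    # mat1:S[1] x mat2:S[0] -> each rank holds a partial sum of the full product
    partial = torch.mm(mat1_shard, mat2_shard)
    # All-Reduce(Output): summing the partials reconstructs mat1 @ mat2
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    # beta * input + alpha * All-Reduce(Output) = res
    return beta * input_tensor + alpha * partial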
@@ -1,64 +1,28 @@
+from copy import copy
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
-
-
-@colo_op_impl(torch.allclose)
-def colo_mean(types, args=(), kwargs=None, pg=None):
-    a = args[0]
-    b = args[1]
-
-    if isinstance(a, ColoTensor):
-        a = a.torch_tensor()
-    elif isinstance(b, ColoTensor):
-        b = b.torch_tensor()
-    if kwargs is None:
-        kwargs = {}
-    return torch.allclose(a, b, **kwargs)
-
-
-@colo_op_impl(torch.mean)
-def colo_mean(types, args=(), kwargs=None, pg=None):
-    input_t = args[0]
-    if isinstance(input_t, ColoTensor):
-        input_t = input_t.torch_tensor()
-    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))
+from ._utils import GeneralTensor


 def register_elementwise_op(op):

     @colo_op_impl(op)
-    def elementwise_op(types, args=(), kwargs=None, pg=None):
+    def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
         """
         Handles ``__torch_function__`` dispatch for the elementwise op such
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
-        input_tensor = args[0]
-        # Validate types
-        if not isinstance(input_tensor, ColoTensor):
-            raise TypeError("input needs to be a ColoTensor")
-        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))
+        output = op(input_tensor, *args, **kwargs)
+        if isinstance(input_tensor, ColoTensor):
+            spec = copy(input_tensor.spec)
+            return ColoTensor.from_torch_tensor(output, spec=spec)
+        return ColoTensor.from_torch_tensor(output)


 register_elementwise_op(torch.nn.functional.gelu)
 register_elementwise_op(torch.nn.functional.relu)
-
-
-@colo_op_impl(torch.sum)
-def sum_op(types, args=(), kwargs=None, pg=None):
-    """
-    Handles ``__torch_function__`` dispatch for the elementwise op such
-    as ``torch.sum`.
-    This method computes on either a normal tensor or a sharded tensor.
-    """
-    if len(args) > 0:
-        input_tensor = args[0]
-    if kwargs is None:
-        kwargs = {}
-    if 'input' in kwargs:
-        input_tensor = kwargs['input']
-    # Validate types
-    if not isinstance(input_tensor, ColoTensor):
-        raise TypeError("input needs to be a ColoTensor")
-    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
+register_elementwise_op(torch.clone)
+register_elementwise_op(torch.Tensor.clone)
+register_elementwise_op(torch.Tensor.detach)
@@ -1,31 +1,52 @@
 import torch
+import torch.nn.functional as F
+from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input
 from colossalai.core import global_context as gpc
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from ._utils import GeneralTensor, convert_to_colo_tensor


-def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
+def colo_embedding_1Dcol(input_tensor: ColoTensor,
+                         weight: ColoTensor,
+                         padding_idx: Optional[int] = None,
+                         max_norm: Optional[float] = None,
+                         norm_type: float = 2.0,
+                         scale_grad_by_freq: bool = False,
+                         sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
     parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
-    input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

-    output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
+    output_parallel = F.embedding(input_tensor,
+                                  weight,
+                                  padding_idx=padding_idx,
+                                  max_norm=max_norm,
+                                  norm_type=norm_type,
+                                  scale_grad_by_freq=scale_grad_by_freq,
+                                  sparse=sparse)
     output_spec = TensorSpec(
-        dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
+        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
         [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
-    output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
-    output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
+    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
+    output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     return output


-def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
+def colo_embedding_1Drow(input_tensor: ColoTensor,
+                         weight: ColoTensor,
+                         padding_idx: Optional[int] = None,
+                         max_norm: Optional[float] = None,
+                         norm_type: float = 2.0,
+                         scale_grad_by_freq: bool = False,
+                         sparse: bool = False) -> ColoTensor:
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
     parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
-    input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

     tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
     num_embeddings_per_partition = weight.size(0)
@@ -33,53 +54,87 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
     vocab_end_index = vocab_start_index + num_embeddings_per_partition

     # Build the mask.
-    input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
-        (input_tensor.torch_tensor() >= vocab_end_index)
+    input_mask = (input_tensor < vocab_start_index) | \
+        (input_tensor >= vocab_end_index)
     # Mask the input.
     # TODO(jzy) masked_input may be an activation managed by ColoTensor.
-    masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
+    masked_input = input_tensor.clone() - vocab_start_index
     masked_input[input_mask] = 0

-    partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(), *args, **kwargs)
+    partial_output = F.embedding(masked_input,
+                                 weight,
+                                 padding_idx=padding_idx,
+                                 max_norm=max_norm,
+                                 norm_type=norm_type,
+                                 scale_grad_by_freq=scale_grad_by_freq,
+                                 sparse=sparse)

     # Mask the output embedding.
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
     output = reduce_input(partial_output, parallel_action.parallel_mode)
-    output = ColoTensor.init_from_torch_tensor(output,
-                                               spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
     return output


-@colo_op_impl(torch.nn.functional.embedding)
-def colo_embedding(types, args, kwargs, pg):
+def colo_embedding_1d(mode: str,
+                      input_tensor: ColoTensor,
+                      weight: ColoTensor,
+                      padding_idx: Optional[int] = None,
+                      max_norm: Optional[float] = None,
+                      norm_type: float = 2.0,
+                      scale_grad_by_freq: bool = False,
+                      sparse: bool = False) -> ColoTensor:
+    assert mode in ('row', 'col')
+    funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
+    return funcs[mode](input_tensor,
+                       weight,
+                       padding_idx=padding_idx,
+                       max_norm=max_norm,
+                       norm_type=norm_type,
+                       scale_grad_by_freq=scale_grad_by_freq,
+                       sparse=sparse)
+
+
+@colo_op_impl(F.embedding)
+def colo_embedding(input_tensor: GeneralTensor,
+                   weight: GeneralTensor,
+                   padding_idx: Optional[int] = None,
+                   max_norm: Optional[float] = None,
+                   norm_type: float = 2.0,
+                   scale_grad_by_freq: bool = False,
+                   sparse: bool = False):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
     This method looks up an embedding table.
     """
-    input_tensor = args[0]
-    weight = args[1]
-    args = args[2:]
-
-    if not isinstance(input_tensor, ColoTensor):
-        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
-
-    if not isinstance(weight, ColoTensor):
-        weight = ColoTensor.init_from_torch_tensor(weight)
+    input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))

     # Handle differen parallel actions.

     if not weight.has_spec():    # No Model Parallel Applied
         assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
-        input_tensor = input_tensor.torch_tensor()
-        weight = weight.torch_tensor()
-        output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
-        return ColoTensor.init_from_torch_tensor(output)
+        return ColoTensor.from_torch_tensor(
+            F.embedding(input_tensor,
+                        weight,
+                        padding_idx=padding_idx,
+                        max_norm=max_norm,
+                        norm_type=norm_type,
+                        scale_grad_by_freq=scale_grad_by_freq,
+                        sparse=sparse))
     elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.spec.is_1D_row():
-            return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
+            mode = 'row'
         elif weight.spec.is_1D_col():
-            return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
+            mode = 'col'
         else:
             raise NotImplementedError
+        return colo_embedding_1d(mode,
+                                 input_tensor,
+                                 weight,
+                                 padding_idx=padding_idx,
+                                 max_norm=max_norm,
+                                 norm_type=norm_type,
+                                 scale_grad_by_freq=scale_grad_by_freq,
+                                 sparse=sparse)
     else:
         raise NotImplementedError
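colo_embedding_1Drow above shards the lookup table by rows, masks indices that fall outside the local vocabulary range, zeroes the corresponding output rows, and relies on one reduction to merge the shards. Here is a hedged, single-rank illustration of just the masking step; the shard bounds are hard-coded for the example and the reduce_input call is omitted.

# Hedged, single-rank illustration of the masking trick used by
# colo_embedding_1Drow; tensor-parallel communication is left out.
import torch
import torch.nn.functional as F

weight_shard = torch.randn(5, 8)            # rows 10..14 of a 20-row table
vocab_start_index, vocab_end_index = 10, 15

indices = torch.tensor([[3, 11, 14, 19]])
input_mask = (indices < vocab_start_index) | (indices >= vocab_end_index)

masked_input = indices.clone() - vocab_start_index
masked_input[input_mask] = 0                # any in-range dummy index works

partial = F.embedding(masked_input, weight_shard)
partial[input_mask, :] = 0.                 # zero rows this shard does not own
# In the real op, reduce_input() then sums `partial` across ranks so each
# token keeps exactly the row from the shard that owns its index.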
@@ -1,39 +1,24 @@
 import torch
+import torch.nn.functional as F
+from typing import List, Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, dist_spec
+from colossalai.tensor import ColoTensor, distspec
+from ._utils import GeneralTensor, convert_to_colo_tensor


-@colo_op_impl(torch.nn.functional.layer_norm)
-def colo_layernorm(types, args=(), kwargs=None, pg=None):
-    arg_num = len(args)
-    if arg_num > 0:
-        input_tensor = args[0]
-    if arg_num > 1:
-        normalized_shape = args[1]
-    if arg_num > 2:
-        weight = args[3]
-    if arg_num > 3:
-        bias = args[4]
-    if arg_num > 4:
-        eps = args[5]
+@colo_op_impl(F.layer_norm)
+def colo_layernorm(
+    input_tensor: GeneralTensor,
+    normalized_shape: List[int],
+    weight: Optional[GeneralTensor] = None,
+    bias: Optional[GeneralTensor] = None,
+    eps: float = 1e-5,
+):
+    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

-    if 'input' in kwargs:
-        input_tensor = kwargs['input']
-    if 'weight' in kwargs:
-        weight = kwargs['weight']
-    if 'bias' in kwargs:
-        bias = kwargs['bias']
-    if 'eps' in kwargs:
-        eps = kwargs['eps']
+    # TODO (ver217): check dist spec
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))

-    if isinstance(input_tensor, ColoTensor):
-        # TODO (ver217): check input dist spec
-        input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
-        input_tensor = input_tensor.torch_tensor()
-    if isinstance(weight, ColoTensor):
-        weight = weight.torch_tensor()
-    if isinstance(bias, ColoTensor):
-        bias = bias.torch_tensor()
-
-    return ColoTensor.init_from_torch_tensor(
-        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
+    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
+    output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
+    return output
@@ -1,108 +1,89 @@
 import torch
+import torch.nn.functional as F
 import torch.distributed as dist
+from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
-from colossalai.nn.layer.utils import divide
-from colossalai.core import global_context as gpc
-from packaging import version
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
+from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
 from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
+from ._utils import GeneralTensor, convert_to_colo_tensor


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
     parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
-    input_tensor.to_dist_spec(
-        dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]))
+    input_tensor = input_tensor.convert_to_dist_spec(
+        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))

     # Output:P
-    partial_output = torch.nn.functional.linear(input_tensor.torch_tensor(), weight.torch_tensor())
+    partial_output = F.linear(input_tensor, weight)
     # Reduce(Output)
     output = reduce_input(partial_output, parallel_action.parallel_mode)
     # Bias
     if bias is not None:
         assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
-        output = output + bias.torch_tensor()
-    output = ColoTensor.init_from_torch_tensor(output,
-                                               spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
+        output = output + bias
+
+    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
     parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
-    input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
-    input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
-    if bias is not None:
-        bias = bias.torch_tensor()
-    output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)

-    output = ColoTensor.init_from_torch_tensor(
+    output_parallel = F.linear(input_parallel, weight, bias)
+    output = ColoTensor.from_torch_tensor(
         output_parallel,
-        spec=TensorSpec(
-            dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
-            [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
+        spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+                        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
     if parallel_action.gather_out:
         # All-Gather(Output)
-        output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
+        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     return output


-@colo_op_impl(torch.nn.functional.linear)
-def colo_linear(types, args, kwargs, pg):
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+    assert mode in ('row', 'col')
+    funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
+    return funcs[mode](input_tensor, weight, bias)
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
-    input_tensor = args[0]
-    weight = args[1]
-
-    if version.parse(torch.__version__) > version.parse("1.11.0"):
-        if len(args) == 3:
-            bias = args[2]
-        else:
-            bias = None
-    else:
-        bias = kwargs.get('bias', None)
-
-    if not isinstance(input_tensor, ColoTensor):
-        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
-
-    if not isinstance(weight, ColoTensor):
-        weight = ColoTensor.init_from_torch_tensor(weight)
-
-    if bias is not None and not isinstance(bias, ColoTensor):
-        bias = ColoTensor.init_from_torch_tensor(bias)
+    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

     # building the computing graph, inputs -> op
     if GraphGlobalEnv().graph_building:
         cur_op_node = GraphOpNode('linear', [weight, bias])
         cur_op_node.add_prev_tensor(input_tensor)

     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
-        assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
-        assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
-        input_tensor = input_tensor.torch_tensor()
-        weight = weight.torch_tensor()
-        if bias is not None:
-            bias = bias.torch_tensor()
-        ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
+        assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
+        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
-            ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
+            mode = 'row'
         elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
-            ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
+            mode = 'col'
         else:
             raise NotImplementedError
+        ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
     else:
         raise NotImplementedError

     # building the computing graph, op -> output
     if GraphGlobalEnv().graph_building:
         cur_op_node.add_post_tensor(ret_tensor)

     return ret_tensor
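With the refactor, dispatch happens through ColoTensor.__torch_function__, so user code simply calls the stock torch.nn.functional.linear on ColoTensor arguments and lands in colo_linear. A hedged single-process sketch follows; no tensor-parallel spec is set up here, so it exercises only the "No Model Parallel Applied" branch, and behaviour in a real TP run should be checked against the handlers above.

# Hedged usage sketch: calling F.linear on ColoTensor arguments routes through
# __torch_function__ into colo_linear. Spec and process-group setup for a real
# tensor-parallel run is omitted.
import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor

x = ColoTensor.from_torch_tensor(torch.randn(2, 4))
w = ColoTensor.from_torch_tensor(torch.randn(3, 4))
b = ColoTensor.from_torch_tensor(torch.randn(3))

# With no spec on the weight this falls into the native-linear branch.
y = F.linear(x, w, b)
assert isinstance(y, ColoTensor)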
@@ -1,40 +1,37 @@
-from colossalai.tensor.dist_spec import DistPlacementPattern
 import torch
+import torch.nn.functional as F
+from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
 from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
+from ._utils import GeneralTensor, convert_to_colo_tensor


-@colo_op_impl(torch.nn.functional.cross_entropy)
-def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
-    arg_num = len(args)
-
-    if arg_num > 0:
-        input_tensor = args[0]
-    if arg_num > 1:
-        target = args[1]
-    if arg_num > 2:
-        weight = args[2]
-
-    if 'input' in kwargs:
-        input_tensor = kwargs.pop('input')
-    if 'target' in kwargs:
-        target = kwargs.pop('target')
-    if 'weight' in kwargs:
-        weight = kwargs.pop('weight')
-
-    if not isinstance(input_tensor, ColoTensor):
-        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
-    if isinstance(target, ColoTensor):
-        target = target.torch_tensor()
+@colo_op_impl(F.cross_entropy)
+def colo_cross_entropy(input_tensor: GeneralTensor,
+                       target: GeneralTensor,
+                       weight: Optional[GeneralTensor] = None,
+                       size_average: Optional[bool] = None,
+                       ignore_index: int = -100,
+                       reduce: Optional[bool] = None,
+                       reduction: str = "mean",
+                       label_smoothing: float = 0.0):
+    input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))

     if input_tensor.spec.is_gathered():    # Input is gathered
-        return ColoTensor.init_from_torch_tensor(
-            torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
+        output = F.cross_entropy(input_tensor,
+                                 target,
+                                 weight=weight,
+                                 size_average=size_average,
+                                 ignore_index=ignore_index,
+                                 reduce=reduce,
+                                 reduction=reduction,
+                                 label_smoothing=label_smoothing)
+        return ColoTensor.from_torch_tensor(output)
     elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
         if input_tensor.spec.is_1D_col():
-            return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
-                                                                                       target))
+            output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
+            return ColoTensor.from_torch_tensor(output)
         else:
             raise NotImplementedError
     else:
@@ -1,6 +1,8 @@
 from .colo_tensor import ColoTensor
 from .const import TensorType
 import torch
+from colossalai.tensor import TensorSpec, distspec
+from copy import copy


 class ColoParameter(ColoTensor):
@@ -8,21 +10,26 @@ class ColoParameter(ColoTensor):

     """

-    def __init__(self, *args, **kargs):
-        super().__init__(*args, **kargs)
-        self._type = TensorType.MODEL
+    def __new__(cls,
+                data: torch.Tensor,
+                requires_grad: bool = True,
+                spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
+        if data is None:
+            data = torch.empty(0)
+        return torch.Tensor._make_subclass(cls, data, requires_grad)

-    def __new__(cls, *args, **kwargs):
-        t = super(ColoParameter, cls).__new__(cls)
-        t._type = TensorType.MODEL
-        return t
+    def __init__(self,
+                 data: torch.Tensor,
+                 requires_grad: bool = True,
+                 spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
+        self._spec = copy(spec)
+        self._type = TensorType.MODEL
+        self._graph_node = None

     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
-        colo_p = ColoParameter(*tensor.size(),
-                               dtype=tensor.dtype,
-                               requires_grad=tensor.requires_grad,
-                               pin_memory=tensor.is_pinned(),
-                               device=tensor.device,
-                               torch_tensor=tensor if save_payload else torch.empty(0))
-        return colo_p
+    def from_torch_tensor(tensor: torch.Tensor,
+                          requires_grad: bool = True,
+                          spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
+        tensor = tensor.as_subclass(ColoParameter)
+        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
+        return tensor
@@ -1,16 +1,23 @@
 from .op_wrapper import _COLOSSAL_OPS
 from copy import copy
 import torch
-from typing import Tuple, Optional, Callable, Union
-from numpy import product
 from colossalai.tensor import TensorSpec
 from .const import TensorType
-from colossalai.tensor import dist_spec
+from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.dist_spec import _DistSpec
+from colossalai.tensor.distspec import _DistSpec
+from torch.overrides import get_default_nowrap_functions


-class ColoTensor(object):
+def _convert_output(output):
+    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
+        output = ColoTensor.from_torch_tensor(output)
+    elif isinstance(output, (list, tuple)):
+        output = type(output)(_convert_output(o) for o in output)
+    return output
+
+
+class ColoTensor(torch.Tensor):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
     2. It supports lazy init the tensor's payload.
@@ -18,120 +25,23 @@ class ColoTensor(object):
     4. It supports distributing the tensor's payload to the shards among processes. (TODO)
     """

-    def __new__(cls, *args, **kwargs):
-        return super(ColoTensor, cls).__new__(cls)
+    def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
+        if data is None:
+            data = torch.empty(0)
+        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

-    def __init__(self,
-                 *size: Tuple[int],
-                 dtype=None,
-                 requires_grad=False,
-                 pin_memory=False,
-                 device=None,
-                 torch_tensor=torch.empty(0),
-                 spec: TensorSpec = TensorSpec(dist_spec.replicate())):
-        self._size = size
-        self._dtype = dtype
-        self._requires_grad = requires_grad
-        self._pin_memory = pin_memory
-        self._device = device
-        self._torch_tensor = torch_tensor
+    def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
         self._spec = copy(spec)
         self._type = TensorType.NONMODEL
         self._graph_node = None

-    def __getitem__(self, key):
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
-
     @property
     def spec(self) -> TensorSpec:
         return self._spec

-    @property
-    def shard_pattern(self):
-        return self._shard_pattern
-
-    @property
-    def data(self):
-        return self._torch_tensor.data
-
-    @data.setter
-    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
-        if isinstance(tensor, ColoTensor):
-            self._torch_tensor.data = tensor.data
-        elif isinstance(tensor, torch.Tensor):
-            self._torch_tensor.data = tensor
-        else:
-            raise NotImplementedError
-
-    @property
-    def grad(self):
-        return self._torch_tensor.grad
-
-    @property
-    def size(self):
-        return self._size
-
-    @property
-    def shape(self):
-        return torch.Size(self._size)
-
-    @property
-    def device(self):
-        return self._torch_tensor.device
-
-    def size(self, dim=None):
-        if dim is None:
-            return self.shape
-        return self._size[dim]
-
-    def dim(self):
-        return len(self._size)
-
-    def normal_(self, mean=0., std=1.):
-        torch_tensor = self.torch_tensor()
-        return torch_tensor.normal_(mean=mean, std=std)
-
-    def numel(self):
-        return product(self._size)
-
-    @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor,
-                               save_payload=True,
-                               spec: TensorSpec = TensorSpec(dist_spec.replicate())) -> 'ColoTensor':
-        colo_t = ColoTensor(*tensor.size(),
-                            dtype=tensor.dtype,
-                            requires_grad=tensor.requires_grad,
-                            pin_memory=tensor.is_pinned(),
-                            device=tensor.device,
-                            torch_tensor=tensor if save_payload else torch.empty(0),
-                            spec=spec)
-        return colo_t
-
-    def del_torch_tensor(self, save_shape=False) -> None:
-        """
-        delete the payload of the torch tensor.
-
-        Args:
-            save_shape (bool, optional): if saving the shape of the torch_tensor.
-            If saving the shape, the size of self._torch_tensor is inconsist with the self._size.
-            Defaults to False.
-        """
-        if not save_shape:
-            self._size = (0,)
-        self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
-
-    def torch_tensor(self) -> torch.Tensor:
-        if self._torch_tensor.numel() == 0:
-            self._torch_tensor = torch.empty(*self._size,
-                                             dtype=self._dtype,
-                                             pin_memory=self._pin_memory,
-                                             requires_grad=self._requires_grad,
-                                             device=self._device)
-        return self._torch_tensor
-
     def set_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
-        self.to_dist_spec(spec.dist_spec)
+        self.convert_to_dist_spec_(spec.dist_spec)
         self._spec = spec

     def has_spec(self) -> bool:
@@ -142,89 +52,51 @@ class ColoTensor(object):

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}

+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
         global _COLOSSAL_OPS
         if func in _COLOSSAL_OPS:
-            for arg in args:
-                if isinstance(arg, ColoTensor):
-                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
+            func = _COLOSSAL_OPS[func]

-            for kwarg in kwargs.values():
-                if isinstance(kwarg, ColoTensor):
-                    return _COLOSSAL_OPS[func](types, args, kwargs, None)
-        else:
-            # If we have not hijact the function, convert the ColoTensors in args and kwargs to torch tensors.
-            args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
-            if kwargs is None:
-                kwargs = {}
+        with torch._C.DisableTorchFunction():
+            ret = func(*args, **kwargs)
+            if func in get_default_nowrap_functions():
+                return ret
+            else:
+                return _convert_output(ret)

-            kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-            return cls._filter_outputs_with_colo(func(*args, **kwargs))
+    def __repr__(self):
+        return f'ColoTensor: {super().__repr__()}'

-    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
-        self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
+    def is_model_data(self) -> bool:
+        return self._type == TensorType.MODEL

-    def __add__(self, o) -> "ColoTensor":
-        if isinstance(o, ColoTensor):
-            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
-        elif isinstance(o, (torch.Tensor, int, float)):
-            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
-        else:
-            raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
-
-    __radd__ = __add__
-
-    def __truediv__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
-
-    def __getattr__(self, name):
-
-        def replace_tensor_with_colo(func):
-
-            def execute_func(*args, **kwargs):
-                # transform the ColoTensor args to torch Tensor.
-                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
-                if kwargs is None:
-                    kwargs = {}
-                kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
-                return self._filter_outputs_with_colo(func(*args, **kwargs))
-
-            return execute_func
-
-        if hasattr(self._torch_tensor, name) == False:
-            raise AttributeError
-
-        attr = getattr(self._torch_tensor, name)
-
-        if isinstance(attr, Callable):
-            return replace_tensor_with_colo(attr)
-        else:
-            return attr
-
-    @classmethod
-    def _filter_outputs_with_colo(cls, outputs):
-        if outputs is None:    # return None
-            return None
-        elif type(outputs) is not tuple:    # num of return val = 1
-            return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
-        else:    # num of return val > 1
-            return tuple([
-                ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
-                for output in outputs
-            ])
-
-    def __mul__(self, other) -> "ColoTensor":
-        if isinstance(other, ColoTensor):
-            return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
-        elif isinstance(other, (torch.Tensor, int, float)):
-            return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other)
-        else:
-            raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')
-
-    __rmul__ = __mul__
-
-    def to_dist_spec(self, dist_spec: _DistSpec) -> None:
-        self._torch_tensor = DistSpecManager.handle_trans_spec(self.torch_tensor(), self.spec.dist_spec, dist_spec)
-        if self._torch_tensor.is_leaf:
-            self._torch_tensor.requires_grad = self._requires_grad
-        self._size = self._torch_tensor.size()
+    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
+        with DistSpecManager.no_grad():
+            self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+        self._spec.dist_spec = dist_spec
+
+    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
+        spec = copy(self._spec)
+        spec.dist_spec = dist_spec
+        ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+        return ColoTensor.from_torch_tensor(ret, spec)
+
+    @staticmethod
+    def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
+        tensor = tensor.as_subclass(ColoTensor)
+        tensor.__init__(tensor, spec=spec)
+        return tensor
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            with torch._C.DisableTorchFunction():
+                data = self.data.clone()
+            tensor = ColoTensor(data, spec=copy(self.spec))
+            memo[id(self)] = tensor
+            return tensor
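The new ColoTensor is a genuine torch.Tensor subclass: __new__ wraps the payload with Tensor._make_subclass, and __torch_function__ either swaps in a registered handler or runs the original op under DisableTorchFunction and re-wraps tensor outputs. Below is a hedged, standalone sketch of that pattern; MyTensor and _HANDLERS are illustrative names for the example, not ColossalAI APIs.

# Hedged, standalone illustration of the subclass + __torch_function__ pattern
# used by the new ColoTensor; names here are made up for the example.
import torch
from torch.overrides import get_default_nowrap_functions

_HANDLERS = {}

class MyTensor(torch.Tensor):

    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _HANDLERS:    # hijacked op: run the custom handler instead
            func = _HANDLERS[func]
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        if func in get_default_nowrap_functions():
            return ret
        # re-wrap plain tensor outputs so the subclass keeps propagating
        return MyTensor(ret) if isinstance(ret, torch.Tensor) else ret

_HANDLERS[torch.mm] = lambda a, b: torch.Tensor.mm(a, b)    # trivial handler

x = MyTensor(torch.randn(2, 2))
y = torch.mm(x, x)
assert isinstance(y, MyTensor)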
@@ -1,4 +1,4 @@
-from colossalai.tensor.dist_spec import _DistSpec
+from colossalai.tensor.distspec import _DistSpec
 from colossalai.nn.layer.utils import divide
 from numpy import prod
 from contextlib import contextmanager
@@ -53,7 +53,7 @@ class DistSpecManager:
     @staticmethod
     def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
         if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
-            and dist_spec.process_group is not None:
+                and dist_spec.process_group is not None:
             raise NotImplementedError
         return tensor

@@ -66,7 +66,7 @@ class DistSpecManager:
     @staticmethod
     def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
         if old_dist_spec.process_group != dist_spec.process_group \
-            and dist_spec.process_group is not None:
+                and dist_spec.process_group is not None:
             raise NotImplementedError
         return DistSpecManager._gather(tensor, old_dist_spec)

@@ -9,11 +9,6 @@ _COLOSSAL_OPS: Dict[str, Callable] = {}


 def _register_colo_op(op, func):
-    from inspect import signature
-    if len(signature(func).parameters) != 4:
-        raise TypeError(f'Custom stateful op function expects signature: '
-                        f'(types, args, kwargs, process_group), but received '
-                        f'signature: {signature(func)}')
     global _COLOSSAL_OPS
     _COLOSSAL_OPS[op] = func

@@ -1,7 +1,8 @@
+import torch.distributed as dist
 from enum import Enum
 from typing import List
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
+from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern


 class ComputePattern(Enum):
@@ -77,6 +78,9 @@ class TensorSpec(object):
     def get_process_group(self):
         return self.dist_spec.process_group

+    def get_process_group_size(self):
+        return dist.get_world_size(self.dist_spec.process_group)
+
     def get_placement(self):
         return self.dist_spec.placement
