[refactor] move process group from _DistSpec to ColoTensor. (#1203)

Author: Jiarui Fang
Date: 2022-07-06 16:15:16 +08:00
Committer: GitHub
Parent: 5da87ce35d
Commit: ae7d3f4927
34 changed files with 452 additions and 367 deletions

View File

@@ -6,15 +6,15 @@ import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup
from colossalai.tensor import ProcessGroup, ColoTensorSpec
GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
def convert_to_colo_tensor(tensor: Optional[GeneralTensor], pg: ProcessGroup) -> Optional[ColoTensor]:
if tensor is not None and not isinstance(tensor, ColoTensor):
tensor = ColoTensor.from_torch_tensor(tensor)
tensor = ColoTensor.from_torch_tensor(tensor, ColoTensorSpec(pg))
return tensor
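For reference, a minimal sketch of how a plain tensor is now promoted to a `ColoTensor` with an explicit process group, based on the signatures in this hunk (the surrounding distributed initialization, e.g. via `colossalai.launch`, is assumed and not shown):

```python
import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

# Assumes torch.distributed / colossalai has already been initialized.
pg = ProcessGroup()                                   # default (DP-only) group
plain = torch.randn(4, 4)
colo = ColoTensor.from_torch_tensor(plain, ColoTensorSpec(pg))
assert colo.get_process_group() is pg

# convert_to_colo_tensor(plain, pg) above performs the same wrapping,
# and simply passes ColoTensors (or None) through unchanged.
```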

View File

@@ -1,7 +1,7 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec, ColoTensorSpec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad
@@ -11,7 +11,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
mat1 = mat1.convert_to_dist_spec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]))
mat1 = mat1.convert_to_dist_spec(distspec.shard([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
@@ -20,20 +20,20 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# input
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.get_process_group())))
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(mat2.get_process_group(), distspec.replicate()))
return output
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
compute_spec = mat2.tensor_spec.compute_spec
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.get_process_group()))
compute_spec = mat2.compute_spec
mat1 = mat1.convert_to_dist_spec(distspec.replicate())
mat1 = reduce_grad(mat1, mat1.get_process_group())
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = TensorSpec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if compute_spec.output_replicate:
@@ -51,27 +51,29 @@ def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: C
@colo_op_impl(torch.addmm)
def colo_addmm(input_tensor: GeneralTensor,
mat1: GeneralTensor,
mat2: GeneralTensor,
*args,
mat1: ColoTensor,
mat2: ColoTensor,
beta: Number = 1,
alpha: Number = 1) -> ColoTensor:
alpha: Number = 1,
*args) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))
# At least one of the tensors should be a ColoTensor
assert isinstance(mat2, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group())
mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group())
# Add communication logic before and after linear call.
ret_tensor = None
if not mat2.has_compute_spec(): # No Model Parallel Applied
assert mat2.tensor_spec.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.tensor_spec.is_shard_1drow() and input_tensor.tensor_spec.is_replicate():
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():
mode = 'row'
elif mat2.tensor_spec.is_shard_1dcol() and (input_tensor.tensor_spec.is_shard_1dcol()
or input_tensor.tensor_spec.is_shard_1drow()):
elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()):
mode = 'col'
else:
raise NotImplementedError
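A sketch of what the new dispatch path expects from callers: `mat2` must already be a `ColoTensor`, and its process group is what a plain `input`/`mat1` tensor is converted onto. The shapes, the `tp_degree=2` group and the initialized distributed context are illustrative assumptions:

```python
import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
                               ComputeSpec, ProcessGroup, distspec)

pg = ProcessGroup(tp_degree=2)                 # 2 ranks, both used for TP
# mat2 is 1D-row sharded: each rank holds a (4/2, 6) slice of the global (4, 6) matrix.
mat2_spec = ColoTensorSpec(pg,
                           distspec.shard([0], [pg.tp_world_size()]),
                           ComputeSpec(ComputePattern.TP1D))
mat2 = ColoTensor.from_torch_tensor(torch.randn(2, 6), mat2_spec)

inp = torch.randn(2, 6)     # plain tensors are converted with mat2's process group
mat1 = torch.randn(2, 4)
out = torch.addmm(inp, mat1, mat2)   # routed to colo_addmm, TP1D 'row' branch
```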

View File

@@ -1,9 +1,8 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from copy import copy
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, ColoTensorSpec
from ._utils import GeneralTensor
@@ -16,11 +15,16 @@ def register_elementwise_op(op):
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
output = op(input_tensor, *args, **kwargs)
if isinstance(input_tensor, ColoTensor):
spec = copy(input_tensor.tensor_spec)
return ColoTensor.from_torch_tensor(output, spec=spec)
return ColoTensor.from_torch_tensor(output)
if not isinstance(output, torch.Tensor):
raise NotImplementedError
return ColoTensor.from_torch_tensor(output,
spec=ColoTensorSpec(input_tensor.process_group,
dist_attr=input_tensor.dist_spec,
compute_attr=input_tensor.compute_spec))
# Tensor op
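After this change an element-wise op no longer copies a `TensorSpec`; it rebuilds a `ColoTensorSpec` from the input's process group, dist spec and compute spec. A minimal sketch, assuming an initialized distributed context and that relu is among the registered ops (as the docstring above suggests):

```python
import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

pg = ProcessGroup()
x = ColoTensor.from_torch_tensor(torch.randn(8), ColoTensorSpec(pg))
y = F.relu(x)                      # handled by the registered element-wise wrapper
assert isinstance(y, ColoTensor)
assert y.get_process_group() is x.get_process_group()
```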

View File

@@ -1,7 +1,7 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
@@ -14,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
output_parallel = F.embedding(input_tensor,
weight,
@@ -23,11 +23,11 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
compute_spec = weight.tensor_spec.compute_spec
compute_spec = weight.compute_spec
if compute_spec.output_replicate:
return output.to_replicate()
@@ -45,10 +45,11 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# embedding_1Drow splits the weight (lookup table) into (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))
pg = weight.get_process_group()
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = weight.tensor_spec.dist_spec.process_group.tp_local_rank()
tensor_parallel_rank = weight.get_process_group().tp_local_rank()
num_embeddings_per_partition = weight.size_local(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -73,7 +74,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, weight.get_process_group())
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
return output
@@ -107,12 +108,11 @@ def colo_embedding(input_tensor: GeneralTensor,
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
This method looks up an embedding table.
"""
input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
# Handle different parallel actions.
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
@@ -121,10 +121,10 @@ def colo_embedding(input_tensor: GeneralTensor,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.tensor_spec.is_shard_1drow():
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1drow():
mode = 'row'
elif weight.tensor_spec.is_shard_1dcol():
elif weight.is_shard_1dcol():
mode = 'col'
else:
raise NotImplementedError
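The 1D-row branch above partitions the vocabulary: each rank owns `num_embeddings / P` consecutive rows, zeroes out lookups for indices it does not own, and relies on an all-reduce to assemble the full result. An illustrative restatement of that masking logic in plain torch (not library code; the helper name and shapes are made up):

```python
import torch
import torch.nn.functional as F

def rowwise_partial_lookup(ids: torch.Tensor, local_table: torch.Tensor,
                           tp_rank: int) -> torch.Tensor:
    """Partial embedding lookup on one rank's vocabulary slice."""
    per_part = local_table.size(0)                 # num_embeddings / P
    start, end = tp_rank * per_part, (tp_rank + 1) * per_part
    outside = (ids < start) | (ids >= end)         # indices owned by other ranks
    local_ids = ids - start
    local_ids[outside] = 0                         # any valid row; zeroed below
    partial = F.embedding(local_ids, local_table)
    partial[outside, :] = 0.                       # filled in by the all-reduce
    return partial                                 # caller all-reduces over TP ranks
```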

View File

@@ -2,7 +2,7 @@ import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -19,7 +19,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
padding_idx: Optional[int] = None) -> ColoTensor:
# embedding_bag_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
pg = weight.get_process_group()
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
output_parallel = F.embedding_bag(input_tensor,
weight,
@@ -32,11 +33,11 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.tensor_spec.compute_spec.output_replicate:
if weight.compute_spec.output_replicate:
return output.to_replicate()
else:
return output
@@ -84,12 +85,13 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
This method looks up an embedding table.
"""
input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
# Handle different parallel actions.
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding_bag(input_tensor,
weight,
@@ -102,8 +104,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx))
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.tensor_spec.is_shard_1dcol():
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol():
tp_mode = 'col'
else:
raise NotImplementedError

View File

@@ -1,8 +1,7 @@
import torch
import torch.nn.functional as F
from typing import List, Optional
import torch.nn.functional as F
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, distspec
from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -14,11 +13,11 @@ def colo_layernorm(
bias: Optional[GeneralTensor] = None,
eps: float = 1e-5,
):
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# TODO (ver217): check dist spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.get_process_group()))
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
return output

View File

@@ -3,7 +3,7 @@ from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
@@ -11,8 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
input_tensor = input_tensor.convert_to_dist_spec(
distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]))
input_tensor = input_tensor.convert_to_dist_spec(distspec.shard([-1], [weight.get_tp_world_size()]))
# Output:P
partial_output = F.linear(input_tensor, weight)
@@ -23,7 +22,8 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
pg = input_tensor.get_process_group()
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
return output
@@ -31,16 +31,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
compute_spec = weight.tensor_spec.compute_spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
input_parallel = reduce_grad(input_tensor, weight.tensor_spec.dist_spec.process_group)
compute_spec = weight.compute_spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(output_parallel,
spec=TensorSpec(
distspec.shard(weight.get_process_group(), [-1],
[weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)))
spec=ColoTensorSpec(weight.get_process_group(),
distspec.shard([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)))
if compute_spec.output_replicate:
return output.to_replicate()
else:
@@ -53,29 +52,32 @@ def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias
return funcs[mode](input_tensor, weight, bias)
@register_colo_graph(input_pos=[1], param_pos=[2, 3])
# @register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_linear_imp(input_tensor: GeneralTensor,
weight: GeneralTensor,
bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear transformation.
"""
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
assert isinstance(weight, ColoTensor)
pg = weight.get_process_group()
input_tensor = convert_to_colo_tensor(input_tensor, pg)
bias = convert_to_colo_tensor(bias, pg)
# input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.tensor_spec.is_shard_1dcol() and (bias is None or bias.tensor_spec.is_replicate()):
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
mode = 'row'
elif weight.tensor_spec.is_shard_1drow() and (bias is None or bias.tensor_spec.is_shard_1drow()
or bias.tensor_spec.is_shard_1dcol()):
elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()):
mode = 'col'
else:
raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}")
ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
else:
raise NotImplementedError
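The linear path mirrors addmm: only the weight needs to be a `ColoTensor`, and its process group is used to convert any plain input or bias. A sketch driving the 'row' branch (weight sharded along its last dimension); the 2-rank TP group, shapes and initialized distributed context are assumptions:

```python
import torch
import torch.nn.functional as F
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
                               ComputeSpec, ProcessGroup, distspec)

pg = ProcessGroup(tp_degree=2)
# Each rank holds a (out_features, in_features / 2) slice of the global (6, 4) weight.
w_spec = ColoTensorSpec(pg,
                        distspec.shard([-1], [pg.tp_world_size()]),
                        ComputeSpec(ComputePattern.TP1D))
weight = ColoTensor.from_torch_tensor(torch.randn(6, 2), w_spec)

x = torch.randn(3, 4)           # replicated input with the global in_features
y = F.linear(x, weight)         # routed to colo_linear_imp, then colo_linear_1Drow
```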

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -16,9 +16,13 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0):
input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor)
pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else target.get_process_group()
weight = convert_to_colo_tensor(weight, pg)
target = convert_to_colo_tensor(target, pg)
input_tensor = convert_to_colo_tensor(input_tensor, pg)
if input_tensor.tensor_spec.is_replicate(): # Input is gathered
if input_tensor.is_replicate(): # Input is gathered
output = F.cross_entropy(input_tensor,
target,
weight=weight,
@@ -27,11 +31,11 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
reduce=reduce,
reduction=reduction,
label_smoothing=label_smoothing)
return ColoTensor.from_torch_tensor(output)
return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
elif input_tensor.has_compute_spec(): # Single Model Parallel Applied
if input_tensor.tensor_spec.is_shard_1dcol():
if input_tensor.is_shard_1dcol():
output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
return ColoTensor.from_torch_tensor(output)
return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
else:
raise NotImplementedError
else:

View File

@@ -23,6 +23,7 @@ def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
def wrapper(*args, **kwargs):
param_list = []
input_list = []
# TODO(jiaruifang) find the pg
for idx, arg in enumerate(args):
if isinstance(arg, torch.Tensor) and idx in input_pos:
input_list.append(convert_to_colo_tensor(arg))

View File

@@ -21,7 +21,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
'weight': distspec.shard([0], [pg.tp_world_size()]),
},
mode='row',
)
@@ -30,7 +30,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
'weight': distspec.shard([-1], [pg.tp_world_size()]),
},
mode='col',
)

View File

@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
'weight': distspec.shard([-1], [pg.tp_world_size()]),
'bias': None
},
mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
'weight': distspec.shard([0], [pg.tp_world_size()]),
'bias': distspec.shard([0], [pg.tp_world_size()])
},
mode='col',
)

View File

@@ -1,5 +1,6 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup
from colossalai.tensor import distspec
from . import ColoModule
import torch
@@ -39,7 +40,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
if not isinstance(param, ColoParameter):
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
if param.has_compute_spec():
cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
cur_compute_pattern = param.compute_spec.compute_pattern
if compute_pattern is None:
compute_pattern = cur_compute_pattern
else:
@@ -62,7 +63,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
for param_name, dist_spec in param_specs.items():
param = module.get_parameter(param_name)
if param.has_compute_spec():
if dist_spec != param.tensor_spec.dist_spec:
if dist_spec != param.dist_spec:
cur_match = False
break
else:
@@ -100,8 +101,8 @@ def init_colo_module(module: torch.nn.Module,
continue
param = module.get_parameter(param_name)
if isinstance(param, ColoParameter):
spec = TensorSpec(dist_spec, compute_spec)
param.set_tensor_spec(spec)
param.set_dist_spec(dist_spec)
param.compute_spec = compute_spec
for mod in param.shared_param_modules:
modules_update_param.add(mod)
for mod in modules_update_param:

View File

@@ -1,5 +1,5 @@
from .process_group import ProcessGroup
from .tensor_spec import TensorSpec
from .tensor_spec import ColoTensorSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
@@ -9,7 +9,7 @@ from .param_op_hook import ParamOpHook, ParamOpHookManager
from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
'ProcessGroup'
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
'ColoTensorSpec', 'TensorSpec'
]

View File

@@ -5,7 +5,7 @@ from copy import copy
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
@@ -28,7 +28,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __new__(cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
spec: ColoTensorSpec = None) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -36,11 +36,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __init__(self,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._tensor_spec = copy(spec)
spec: ColoTensorSpec = None) -> None:
ColoTensor.__init__(self, data, spec)
self._type = TensorType.MODEL
self._graph_node = None
# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []
@@ -51,7 +49,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
spec: ColoTensorSpec = None) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor
@@ -82,7 +80,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
tensor = ColoParameter(data,
self.requires_grad,
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
memo[id(self)] = tensor
return tensor
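With `spec` now defaulting to `None`, a bare `ColoParameter` falls back to the behaviour defined in `ColoTensor.__init__` below: a default (DP) process group and a replicated dist spec. A minimal sketch, assuming an initialized distributed context:

```python
import torch
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup

# No spec: replicated layout on a default process group, no compute spec.
p = ColoParameter(torch.randn(4, 4))
assert p.is_replicate() and p.compute_spec is None

# Explicit spec: the parameter is pinned to a given process group up front.
pg = ProcessGroup()
q = ColoParameter(torch.randn(4, 4), spec=ColoTensorSpec(pg))
assert q.get_process_group() is pg
```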

View File

@@ -4,18 +4,18 @@ from copy import copy
import torch
from torch.overrides import get_default_nowrap_functions
from colossalai.tensor import TensorSpec
from colossalai.tensor import distspec
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional
def _convert_output(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output = ColoTensor.from_torch_tensor(output)
def _check_output(output):
if not isinstance(output, torch.Tensor):
raise RuntimeError
elif isinstance(output, (list, tuple)):
output = type(output)(_convert_output(o) for o in output)
output = type(output)(_check_output(o) for o in output)
return output
@@ -23,28 +23,29 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to None, which means a replicated layout on a default process group.
The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
>>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = TensorSpec(shard_spec)
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
"""
def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""__new__
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
spec (ColoTensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrapping the data.
"""
@@ -52,37 +53,72 @@ class ColoTensor(torch.Tensor):
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._tensor_spec = copy(spec)
def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
# If no spec is given, use a DP process group and a replicated dist spec
if not spec:
self.has_initialized = False
self.dist_spec = distspec.replicate()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
self.has_initialized = True
self.dist_spec = spec.dist_attr
self.compute_spec = spec.compute_attr
self.process_group = spec.pg
self._type = TensorType.NONMODEL
self._graph_node = None
@property
def tensor_spec(self) -> TensorSpec:
return self._tensor_spec
@tensor_spec.setter
def tensor_spec(self, tenseor_spec: TensorSpec):
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec
def set_tensor_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec
def has_compute_spec(self) -> bool:
return self._tensor_spec.compute_spec is not None
return self.compute_spec is not None
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def get_process_group(self) -> 'ProcessGroup':
return self._tensor_spec.dist_spec.process_group
return self.process_group
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
Change the pg of the ColoTensor. Note that the valid use cases are limited:
it is only allowed when the existing pg is DP-only and the dist spec is REPLICATE.
Args:
pg (ProcessGroup): target pg
Raises:
RuntimeError: if the current process group already has a TP group, or the dist spec is not REPLICATE.
"""
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
if self.process_group.tp_world_size() != 1:
raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
if self.dist_spec.placement.value != 'r':
raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
self.process_group = pg
def get_tp_world_size(self) -> int:
return self._tensor_spec.dist_spec.process_group.tp_world_size()
return self.process_group.tp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
set dist spec and change the payloads.
Args:
dist_spec (_DistSpec): target dist spec.
"""
assert isinstance(dist_spec, _DistSpec)
self._convert_to_dist_spec(dist_spec)
def set_tensor_spec(self, dist_spec, compute_spec):
if dist_spec:
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
self.set_dist_spec(dist_spec)
if compute_spec:
self.compute_spec = compute_spec
def has_compute_pattern(self, compute_pattern):
return self.compute_spec.compute_pattern == compute_pattern
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -100,7 +136,9 @@ class ColoTensor(torch.Tensor):
if func in get_default_nowrap_functions():
return ret
else:
return _convert_output(ret)
# TODO(jiaruifang) its parallel Op's duty to convert output activations
return ret
# return _check_output(ret)
def __repr__(self):
return f'ColoTensor: {super().__repr__()}'
@@ -113,30 +151,28 @@ class ColoTensor(torch.Tensor):
dist_spec (_DistSpec): the target dist. spec.
"""
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
self._tensor_spec.dist_spec = dist_spec
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
tensor_spec = copy(self._tensor_spec)
tensor_spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, tensor_spec)
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
def to_replicate_(self):
"""to_replicate_
an in-place member function that converts the dist spec of the tensor to REPLICATE
"""
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
self.dist_spec = distspec.replicate()
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
return self.convert_to_dist_spec(distspec.replicate())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
@@ -147,7 +183,7 @@ class ColoTensor(torch.Tensor):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
memo[id(self)] = tensor
return tensor
@@ -165,12 +201,13 @@ class ColoTensor(torch.Tensor):
Returns:
ColoTensor: the viewed tensor.
"""
if self.tensor_spec.is_replicate():
if self.is_replicate():
return super().view(*args)
# TODO(jiaruifang) check why this not work
# self.data = self.to_replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
self.process_group)
self.dist_spec = distspec.replicate()
return super().view(*args)
def size_global(self, args: Optional[int] = None):
@@ -179,13 +216,13 @@ class ColoTensor(torch.Tensor):
Returns:
torch.Size: the global size of the tensor.
"""
if self.tensor_spec.is_replicate():
if self.is_replicate():
if args is not None:
return super().size(args)
else:
return super().size()
spec = self.tensor_spec.dist_spec
spec = self.dist_spec
dims = spec.dims
num_partitions = spec.num_partitions
# import inspect
@@ -198,3 +235,19 @@ class ColoTensor(torch.Tensor):
return size_list[args]
else:
return torch.Size(size_list)
# Some API for dist spec check
def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.process_group.tp_world_size() == 1)
def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
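Taken together, the tensor now owns its process group and exposes the layout predicates directly. A sketch of the intended lifecycle on 2 GPUs (the default pg is assumed to be DP-only; the shard/gather conversions go through `DistSpecManager`, so a CUDA + NCCL setup is assumed for the gather):

```python
import torch
from colossalai.tensor import ColoTensor, ProcessGroup, distspec

t = ColoTensor.from_torch_tensor(torch.randn(4, 4).cuda())   # default pg, replicated
t.set_process_group(ProcessGroup(tp_degree=2))               # legal: still replicated, DP-only pg

t.set_dist_spec(distspec.shard([0], [t.get_tp_world_size()]))  # split rows across TP ranks
assert t.is_shard_1drow() and not t.is_replicate()

t.to_replicate_()                                              # all-gather back in place
assert t.is_replicate()
```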

View File

@@ -6,6 +6,7 @@ import torch
import torch.distributed as dist
from packaging import version
from colossalai.logging import get_dist_logger
from colossalai.tensor import ProcessGroup
# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -29,15 +30,17 @@ def divide(numerator, denominator):
class TransformDistSpec(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, old_dist_spec, dist_spec, forward_trans_func, backward_trans_func):
def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
ctx.old_dist_spec = old_dist_spec
ctx.dist_spec = dist_spec
ctx.backward_trans_func = backward_trans_func
return forward_trans_func(tensor, old_dist_spec, dist_spec)
ctx.pg = pg
return forward_trans_func(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def backward(ctx, grad_outputs):
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec), None, None, None, None
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
ctx.pg), None, None, None, None, None
class DistSpecManager:
@@ -46,18 +49,17 @@ class DistSpecManager:
@staticmethod
def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
raise NotImplementedError
pass
@staticmethod
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
"""_shard_as: shard the tensor w.r.t a distributed specification.
Assuming the tensor passed in is a global (replicated) tensor.
Args:
tensor (torch.Tensor): a global (replicated) tensor before shard
dist_spec (_DistSpec): the distributed spec. to be sharded as.
pg (ProcessGroup): the process group of the corresponding ColoTensor
Returns:
torch.Tensor: a torch tensor after sharded.
"""
@@ -65,7 +67,7 @@ class DistSpecManager:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
chunk = tensor
idx = dist_spec.process_group.tp_local_rank()
idx = pg.tp_local_rank()
num_parts = prod(dist_spec.num_partitions)
for i, dim in enumerate(dist_spec.dims):
num_parts //= dist_spec.num_partitions[i]
@@ -76,7 +78,7 @@ class DistSpecManager:
return chunk.clone().detach().contiguous()
@staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
"""_gather gather sharded tensors to a replicated one.
Args:
tensor (torch.Tensor): a shared torch tensor
@@ -92,9 +94,9 @@ class DistSpecManager:
saved_dev = tensor.device
tensor.data = tensor.data.cuda()
buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda'
dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
dist.all_gather(buffer, tensor, group=pg.tp_process_group())
for i in range(len(old_dist_spec.dims) - 1, -1, -1):
new_buffer = []
dim = old_dist_spec.dims[i]
@@ -109,12 +111,14 @@ class DistSpecManager:
return buffer[0]
@staticmethod
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
world_size = old_dist_spec.process_group.tp_world_size()
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
world_size = pg.tp_world_size()
if world_size == 1:
return tensor
assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
assert tensor.device.type == "cuda", \
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
f"collective function, however, we got {tensor.device.type} device"
gather_dim = old_dist_spec.dims[0]
@@ -126,46 +130,50 @@ class DistSpecManager:
scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
return output_
@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return tensor
@staticmethod
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec, pg)
@staticmethod
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
if old_dist_spec == dist_spec:
return tensor
if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
# use all-to-all to save memory
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec)
tensor = DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
if not DistSpecManager._use_autograd_function:
return forward_trans_handle(tensor, old_dist_spec, dist_spec)
return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
backward_trans_handle = getattr(DistSpecManager,
f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, forward_trans_handle, backward_trans_handle)
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
backward_trans_handle)
@staticmethod
@contextmanager
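Since the dist spec no longer knows its process group, every conversion helper here takes `pg` explicitly. A sketch of a direct call (2 TP ranks, CUDA tensors and an NCCL backend assumed for the all-gather):

```python
import torch
from colossalai.tensor import ProcessGroup, distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager

pg = ProcessGroup(tp_degree=2)
local_shard = torch.randn(2, 4).cuda()        # this rank's half of a global (4, 4) tensor
old_spec = distspec.shard([0], [pg.tp_world_size()])
new_spec = distspec.replicate()

# pg is now the fourth argument of handle_trans_spec (and of every _x2y helper).
gathered = DistSpecManager.handle_trans_spec(local_shard, old_spec, new_spec, pg)
assert gathered.shape == (4, 4)
```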

View File

@@ -1,7 +1,5 @@
from enum import Enum
from colossalai.tensor import ProcessGroup
from typing import Optional, List
from numpy import prod
from typing import List
__all__ = ['replicate', 'shard']
@@ -13,10 +11,7 @@ class DistPlacementPattern(Enum):
class _DistSpec:
def __init__(self,
dist_placement_pattern: DistPlacementPattern,
process_group: Optional[ProcessGroup] = None,
**meta_info):
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
"""_DistSpec, Distributed Specification
Args:
@@ -25,7 +20,6 @@ class _DistSpec:
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
self.placement = dist_placement_pattern
self.process_group = process_group
for k, v in meta_info.items():
setattr(self, k, v)
@@ -45,14 +39,11 @@ class _DistSpec:
return res
def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
# process_group=None means global process group
return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
def replicate() -> _DistSpec:
return _DistSpec(DistPlacementPattern.REPLICATE)
def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
assert process_group is not None and isinstance(process_group, ProcessGroup)
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
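A dist spec is now purely a layout description; the group that realizes it comes from the owning `ColoTensor`. A short sketch of the slimmed-down factories:

```python
from colossalai.tensor import distspec

rep = distspec.replicate()                           # no process group argument anymore
row = distspec.shard(dims=[0], num_partitions=[4])   # split dim 0 into 4 parts
col = distspec.shard(dims=[-1], num_partitions=[4])  # split the last dim into 4 parts

# Note: the old check prod(num_partitions) == pg.tp_world_size() was dropped here,
# so that consistency is now the responsibility of whoever pairs the spec with a pg.
```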

View File

@@ -3,6 +3,7 @@ from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec
class ParamOpHook(ABC):
@@ -129,7 +130,7 @@ def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, arg.tensor_spec))
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
else:
info.append(None)
return info

View File

@@ -20,6 +20,9 @@ class ProcessGroup:
ranks: Optional[List[int]] = None,
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
if not torch.distributed.is_initialized():
return
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
if rank is None:
self._rank = torch.distributed.get_rank()

View File

@@ -1,44 +1,12 @@
import torch.distributed as dist
from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from .compute_spec import ComputeSpec, ComputePattern
from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass
class TensorSpec(object):
"""
The specification of the ColoTensor.
Args:
dist_spec (_DistSpec): describing the layout among processes.
compute_spec (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
Defaults to None.
"""
def __init__(self, dist_spec: _DistSpec, compute_spec: Optional[ComputeSpec] = None):
self.compute_spec = compute_spec
self.dist_spec = dist_spec
def get_process_group(self):
return self.dist_spec.process_group
def get_placement(self):
return self.dist_spec.placement
def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.dist_spec.process_group.tp_world_size() == 1)
def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def has_compute_pattern(self, compute_pattern: ComputePattern):
return self.compute_spec.compute_pattern == compute_pattern
def __repr__(self):
return f'parallel action: {self.compute_spec}, dist_spec: {self.dist_spec}'
@dataclass
class ColoTensorSpec:
pg: ProcessGroup
dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
compute_attr: Optional[ComputeSpec] = None
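`TensorSpec` is replaced by this small dataclass; only the process group is mandatory, and the layout/compute fields default to a replicated tensor with no compute pattern. A usage sketch (initialized distributed context assumed):

```python
from colossalai.tensor import (ColoTensorSpec, ComputePattern, ComputeSpec,
                               ProcessGroup, distspec)

pg = ProcessGroup(tp_degree=2)

plain_spec = ColoTensorSpec(pg)          # replicated, no compute spec

tp_spec = ColoTensorSpec(pg,             # TP1D column-sharded parameter
                         dist_attr=distspec.shard([-1], [pg.tp_world_size()]),
                         compute_attr=ComputeSpec(ComputePattern.TP1D))
```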

View File

@@ -1,6 +1,6 @@
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter, distspec, TensorSpec
from colossalai.tensor import ColoTensor, ColoParameter, distspec
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding
@@ -36,16 +36,17 @@ def ColoModulize(module):
def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
# build param to spec mapping
mapping = dict()
mapping1 = dict()
mapping2 = dict()
# gather all params
has_dist_parameter = False
with torch.no_grad():
for param in self.parameters():
if isinstance(param, ColoParameter) and param.has_compute_spec():
has_dist_parameter = True
mapping[id(param)] = copy(param.tensor_spec)
param.set_tensor_spec(TensorSpec(distspec.replicate()))
mapping1[id(param)] = copy(param.dist_spec)
mapping2[id(param)] = copy(param.compute_spec)
param.set_dist_spec(distspec.replicate())
# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
@@ -60,9 +61,10 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
with torch.no_grad():
for param in self.parameters():
param_id = id(param)
if param_id in mapping:
spec = mapping[id(param)]
param.set_tensor_spec(spec)
if param_id in mapping1:
dist_spec = mapping1[id(param)]
compute_spec = mapping2[id(param)]
param.set_tensor_spec(dist_spec, compute_spec)
return ret
@@ -122,7 +124,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
save_torch_payload = True if not self._lazy_memory_allocate else False
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# TODO(jiaruifang) we initialize a Default PG memory
colo_param = ColoParameter(param.to(self._device), requires_grad=requires_grad)
# add mapping record
replaced_tensors[param] = colo_param
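The state-dict hook above now has to stash the two specs separately (there is no single `TensorSpec` to copy any more): replicate every distributed parameter, take the ordinary state dict, then restore layout and compute spec with the new `set_tensor_spec(dist_spec, compute_spec)`. A condensed restatement of that pattern (illustrative helper, not the library function; a CUDA/NCCL setup is assumed for the gather):

```python
import torch
from copy import copy
from colossalai.tensor import ColoParameter, distspec

def state_dict_with_gather(model: torch.nn.Module):
    """Gather sharded ColoParameters for a full state_dict, then re-shard them."""
    saved = {}
    with torch.no_grad():
        for p in model.parameters():
            if isinstance(p, ColoParameter) and p.has_compute_spec():
                saved[id(p)] = (copy(p.dist_spec), copy(p.compute_spec))
                p.set_dist_spec(distspec.replicate())        # all-gather the shards
    state = model.state_dict()
    with torch.no_grad():
        for p in model.parameters():
            if id(p) in saved:
                dist_spec, compute_spec = saved[id(p)]
                p.set_tensor_spec(dist_spec, compute_spec)   # restore the sharding
    return state
```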