mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[tensor] a shorter shard and replicate spec (#1245)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from torch.fx.node import map_arg
|
||||
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
|
||||
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
|
||||
|
||||
|
||||
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
|
||||
@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
# As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
|
||||
setattr(weight, "fx_attr", spec)
|
||||
return weight
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
|
||||
from colossalai.tensor import distspec, ColoTensorSpec
|
||||
from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
|
||||
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
|
||||
from ._utils import reduce_input, reduce_grad
|
||||
|
||||
@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
||||
# mat1:S[1] x mat2:S[0] = Output:P
|
||||
# beta * input + alpha * All-Reduce(Output) = res
|
||||
|
||||
mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
|
||||
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
|
||||
|
||||
|
||||
# Output:P
|
||||
partial_output = torch.mm(mat1, mat2)
|
||||
@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
||||
# input
|
||||
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
|
||||
output = beta * input_tensor + alpha * output
|
||||
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate()))
|
||||
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
|
||||
return output
|
||||
|
||||
|
||||
@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
||||
alpha: Number) -> ColoTensor:
|
||||
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
|
||||
compute_spec = mat2.compute_spec
|
||||
mat1 = mat1.redistribute(distspec.replicate())
|
||||
mat1 = mat1.redistribute(ReplicaSpec())
|
||||
mat1 = reduce_grad(mat1, mat1.get_process_group())
|
||||
|
||||
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
|
||||
output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
|
||||
output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
|
||||
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
||||
sparse: bool = False) -> ColoTensor:
|
||||
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
output_parallel = F.embedding(input_tensor,
|
||||
weight,
|
||||
@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
||||
norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse)
|
||||
output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
|
||||
output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
|
||||
# Find index in this shard and mask those not here
|
||||
# Reduce all
|
||||
pg = weight.get_process_group()
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
tensor_parallel_rank = weight.get_process_group().tp_local_rank()
|
||||
@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
|
||||
partial_output[input_mask, :] = 0.
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_input(partial_output, weight.get_process_group())
|
||||
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
|
||||
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
|
||||
return output
|
||||
|
||||
|
||||
|
@@ -2,7 +2,7 @@ import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
from torch import Tensor
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
|
||||
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
|
||||
# embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
pg = weight.get_process_group()
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
output_parallel = F.embedding_bag(input_tensor,
|
||||
weight,
|
||||
@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
|
||||
per_sample_weights=per_sample_weights,
|
||||
include_last_offset=include_last_offset,
|
||||
padding_idx=padding_idx)
|
||||
output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
if weight.compute_spec.output_replicate:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
import torch.nn.functional as F
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
|
||||
from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ def colo_layernorm(
|
||||
assert isinstance(weight, ColoTensor)
|
||||
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
|
||||
bias = convert_to_colo_tensor(bias, weight.get_process_group())
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
|
||||
|
@@ -3,8 +3,7 @@ from typing import Optional
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from ._utils import reduce_input, reduce_grad
|
||||
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
|
||||
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
|
||||
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
|
||||
|
||||
|
||||
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
|
||||
@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
||||
# All-Reduce(Output) + bias = res
|
||||
# Input:S[1]
|
||||
pg = weight.get_process_group()
|
||||
input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
|
||||
input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
|
||||
|
||||
# Output:P
|
||||
partial_output = F.linear(input_tensor, weight)
|
||||
@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
||||
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
|
||||
output = output + bias
|
||||
|
||||
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
|
||||
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
|
||||
return output
|
||||
|
||||
|
||||
@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
||||
# All-Gather(Output)
|
||||
# Input:B
|
||||
compute_spec = weight.compute_spec
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
|
||||
|
||||
output_parallel = F.linear(input_parallel, weight, bias)
|
||||
output = ColoTensor.from_torch_tensor(output_parallel,
|
||||
spec=ColoTensorSpec(weight.get_process_group(),
|
||||
distspec.shard([-1], [weight.get_tp_world_size()]),
|
||||
ShardSpec([-1], [weight.get_tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D)))
|
||||
if compute_spec.output_replicate:
|
||||
return output.to_replicate()
|
||||
|
@@ -7,16 +7,6 @@ class ColoModule(object):
|
||||
|
||||
def __init__(self):
|
||||
self._shard_params: List[str] = []
|
||||
# Example:
|
||||
# {ComputePattern.TP1D:
|
||||
# 'default':
|
||||
# 'weight':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'bias':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'row': ...
|
||||
# 'col': ...
|
||||
# }
|
||||
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
|
||||
|
||||
def _register_shard_params(self, params: List[str]):
|
||||
|
@@ -1,7 +1,5 @@
|
||||
from .colo_module import ColoModule
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
|
||||
|
||||
|
||||
class ColoEmbedding(ColoModule):
|
||||
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([0], [pg.tp_world_size()]),
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
},
|
||||
mode='row',
|
||||
)
|
||||
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([-1], [pg.tp_world_size()]),
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
},
|
||||
mode='col',
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from .colo_module import ColoModule
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
|
||||
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
|
||||
|
||||
|
||||
class ColoLinear(ColoModule):
|
||||
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([-1], [pg.tp_world_size()]),
|
||||
'weight': ShardSpec([-1], [pg.tp_world_size()]),
|
||||
'bias': None
|
||||
},
|
||||
mode='row',
|
||||
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard([0], [pg.tp_world_size()]),
|
||||
'bias': distspec.shard([0], [pg.tp_world_size()])
|
||||
'weight': ShardSpec([0], [pg.tp_world_size()]),
|
||||
'bias': ShardSpec([0], [pg.tp_world_size()])
|
||||
},
|
||||
mode='col',
|
||||
)
|
||||
|
@@ -1,5 +1,8 @@
|
||||
from .process_group import ProcessGroup
|
||||
from .tensor_spec import ColoTensorSpec
|
||||
from .distspec import shard as ShardSpec
|
||||
from .distspec import replicate as ReplicaSpec
|
||||
|
||||
from .compute_spec import ComputeSpec, ComputePattern
|
||||
from .colo_tensor import ColoTensor
|
||||
from .colo_parameter import ColoParameter
|
||||
@@ -11,5 +14,5 @@ from . import distspec
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
|
||||
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
|
||||
'ColoTensorSpec', 'TensorSpec'
|
||||
'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
|
||||
]
|
||||
|
@@ -5,7 +5,7 @@ import torch
|
||||
from functools import lru_cache
|
||||
|
||||
from colossalai.tensor import ColoTensorSpec
|
||||
from colossalai.tensor import distspec, ProcessGroup
|
||||
from colossalai.tensor import ProcessGroup, ReplicaSpec
|
||||
from colossalai.tensor.dist_spec_mgr import DistSpecManager
|
||||
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
|
||||
from typing import Optional, Set, Callable
|
||||
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
|
||||
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
|
||||
Args:
|
||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
|
||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
||||
|
||||
The signature of the function has to be consistent with the __new__ except for the 1st arg.
|
||||
The class should be initialized with a torch tensor in the following ways.
|
||||
1. directly init.
|
||||
>>> pg = ProcessGroup()
|
||||
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
|
||||
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
|
||||
>>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
|
||||
>>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
|
||||
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
|
||||
>>> dims=[0],
|
||||
>>> num_partitions=[world_size])
|
||||
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
|
||||
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
2. use static method from_torch_tensor
|
||||
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
|
||||
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
|
||||
"""
|
||||
|
||||
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
|
||||
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
|
||||
# If not set spec, use a DP process group and replicate dist spec
|
||||
if spec is None:
|
||||
self.has_initialized = False
|
||||
self.dist_spec = distspec.replicate()
|
||||
self.dist_spec = ReplicaSpec()
|
||||
self.compute_spec = None
|
||||
self.process_group = ProcessGroup()
|
||||
else:
|
||||
@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
|
||||
"""to_replicate_
|
||||
an inline member function, converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
self._redistribute(dist_spec=distspec.replicate())
|
||||
self._redistribute(dist_spec=ReplicaSpec())
|
||||
|
||||
def to_replicate(self) -> 'ColoTensor':
|
||||
"""to_replicate
|
||||
converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
return self.redistribute(distspec.replicate())
|
||||
return self.redistribute(ReplicaSpec())
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
|
||||
@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
|
||||
"""
|
||||
if self.is_replicate():
|
||||
return super().view(*args)
|
||||
replicated_t = self.redistribute(dist_spec=distspec.replicate())
|
||||
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
|
||||
return replicated_t.view(*args)
|
||||
|
||||
def size_global(self, args: Optional[int] = None):
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup
|
||||
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
|
||||
|
||||
from colossalai.nn.parallel.layers import register_colo_module, \
|
||||
ColoLinear, ColoEmbedding
|
||||
|
Reference in New Issue
Block a user