[refactor] remove gpc dependency in colotensor's _ops (#1189)
@@ -1,6 +1,12 @@
import torch
from typing import Union, Optional
from colossalai.tensor import ColoTensor
import torch
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env

from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup

GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
@@ -10,3 +16,182 @@ def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
    if tensor is not None and not isinstance(tensor, ColoTensor):
        tensor = ColoTensor.from_torch_tensor(tensor)
    return tensor


def set_parallel_input(input_parallel: bool):
    env.parallel_input_1d = input_parallel


def get_parallel_input():
    return env.parallel_input_1d


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)


def _reduce(input_, pg: ProcessGroup):
    # skip if only one rank involved
    if pg.tp_world_size() == 1:
        return input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    dist.all_reduce(input_, group=group)

    return input_


def _split(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # Split along last dimension.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = pg.tp_local_rank()
    output = tensor_list[rank].contiguous()

    return output


def _gather(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # all gather
    rank = pg.tp_local_rank()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


class _ReduceGrad(torch.autograd.Function):
    """
    Pass the input to the model parallel region.

    Args:
        input_: input matrix.
        process_group: parallel mode.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_, process_group):
        ctx.mode = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output, ctx.mode), None


class _ReduceInput(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        process_group: parallel mode.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_, process_group):
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the corresponding chunk to the rank.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _split(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.mode, ctx.dim), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _gather(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.mode, ctx.dim), None, None


def reduce_grad(input_, process_group):
    return _ReduceGrad.apply(input_, process_group)


def reduce_input(input_, process_group):
    return _ReduceInput.apply(input_, process_group)


def split_forward_gather_backward(input_, process_group, dim):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim)


def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)
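Editorial note (not part of the commit): the helpers above are intended to be driven through their public wrappers. A minimal usage sketch, assuming torch.distributed and the Colossal-AI launcher are already initialized, a CUDA device is available, and the wrappers are imported from the module shown above; the tp_degree value is illustrative:

import torch
from colossalai.tensor import ProcessGroup

# split_forward_gather_backward / gather_forward_split_backward come from the module above
pg = ProcessGroup(tp_degree=2)          # illustrative 2-way tensor-parallel group
x = torch.randn(4, 8, device='cuda')    # _split/_gather assert CUDA tensors

# forward: keep only this rank's chunk along the last dim; backward: all-gather the grads
x_shard = split_forward_gather_backward(x, pg, dim=-1)

# forward: all-gather the chunks along the last dim; backward: re-split the grads
x_full = gather_forward_split_backward(x_shard, pg, dim=-1)
assert x_full.shape == x.shape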
@@ -1,10 +1,9 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -12,18 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor,
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res

    mat1 = mat1.convert_to_dist_spec(
        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
    mat1 = mat1.convert_to_dist_spec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]))

    # Output:P
    partial_output = torch.mm(mat1, mat2)
    # Reduce(Output)
    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
    output = reduce_input(partial_output, mat1.get_process_group())
    # input
    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor + alpha * output
    output = ColoTensor.from_torch_tensor(output,
                                          spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.get_process_group())))
    return output


@@ -31,13 +28,12 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor,
                     alpha: Number) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    compute_spec = mat2.tensor_spec.compute_spec
    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.get_process_group()))
    mat1 = reduce_grad(mat1, mat1.get_process_group())

    output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
    output_spec = TensorSpec(
        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
        ComputeSpec(ComputePattern.TP1D))
    output_spec = TensorSpec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if compute_spec.output_replicate:
@@ -1,11 +1,8 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, convert_to_colo_tensor
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


def colo_embedding_1Dcol(input_tensor: ColoTensor,
@@ -17,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
    # Gather split lookup table
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

    output_parallel = F.embedding(input_tensor,
                                  weight,
@@ -26,9 +23,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                  norm_type=norm_type,
                                  scale_grad_by_freq=scale_grad_by_freq,
                                  sparse=sparse)
    output_spec = TensorSpec(
        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
        ComputeSpec(ComputePattern.TP1D))
    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    compute_spec = weight.tensor_spec.compute_spec
@@ -49,9 +45,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
    # Find index in this shard and mask those not here
    # Reduce all
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    tensor_parallel_rank = weight.tensor_spec.dist_spec.process_group.tp_local_rank()
    num_embeddings_per_partition = weight.size_local(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -75,9 +72,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # Mask the output embedding.
    partial_output[input_mask, :] = 0.
    # Reduce across all the model parallel GPUs.
    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
    output = ColoTensor.from_torch_tensor(output,
                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
    output = reduce_input(partial_output, weight.get_process_group())
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
    return output
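Aside (illustrative, not from the diff): the mask-and-reduce scheme in colo_embedding_1Drow can be sanity-checked in a single process. Each shard owns a contiguous vocab range, zeroes the rows for ids it does not own, and the per-shard partial outputs sum to the full lookup, which is exactly what reduce_input recovers. Sizes below are made up:

import torch
import torch.nn.functional as F

vocab_size, embed_dim, world_size = 8, 4, 2
weight = torch.randn(vocab_size, embed_dim)
ids = torch.randint(0, vocab_size, (5,))

full = F.embedding(ids, weight)

per_partition = vocab_size // world_size
partials = []
for rank in range(world_size):
    start, end = rank * per_partition, (rank + 1) * per_partition
    shard = weight[start:end]                 # weight: S[0]
    mask = (ids < start) | (ids >= end)       # ids this shard does not own
    local_ids = ids - start
    local_ids[mask] = 0                       # any valid index works, the row is zeroed below
    partial = F.embedding(local_ids, shard)
    partial[mask, :] = 0.                     # mask the output embedding
    partials.append(partial)

assert torch.allclose(sum(partials), full)    # the all-reduce recovers the full lookup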
@@ -32,9 +32,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                      per_sample_weights=per_sample_weights,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx)
    output_spec = TensorSpec(
        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
        ComputeSpec(ComputePattern.TP1D))
    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if weight.tensor_spec.compute_spec.output_replicate:
@@ -17,7 +17,7 @@ def colo_layernorm(
    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

    # TODO (ver217): check dist spec
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.get_process_group()))

    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
    output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
@@ -2,9 +2,8 @@ import torch.nn.functional as F
from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.context import ParallelMode
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


@@ -13,19 +12,18 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional
    # All-Reduce(Output) + bias = res
    # Input:S[1]
    input_tensor = input_tensor.convert_to_dist_spec(
        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]))

    # Output:P
    partial_output = F.linear(input_tensor, weight)
    # Reduce(Output)
    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
    output = reduce_input(partial_output, weight.get_process_group())
    # Bias
    if bias is not None:
        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
        output = output + bias

    output = ColoTensor.from_torch_tensor(output,
                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
    return output


@@ -35,13 +33,13 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional
    # Input:B
    compute_spec = weight.tensor_spec.compute_spec
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
    input_parallel = reduce_grad(input_tensor, weight.tensor_spec.dist_spec.process_group)

    output_parallel = F.linear(input_parallel, weight, bias)
    output = ColoTensor.from_torch_tensor(output_parallel,
                                          spec=TensorSpec(
                                              distspec.shard(weight.tensor_spec.get_process_group(), [-1],
                                                             [weight.tensor_spec.get_process_group_size()]),
                                              distspec.shard(weight.get_process_group(), [-1],
                                                             [weight.get_tp_world_size()]),
                                              ComputeSpec(ComputePattern.TP1D)))
    if compute_spec.output_replicate:
        return output.to_replicate()
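Aside (illustrative, not from the diff): the reason the 1D-row path only needs an all-reduce is plain linear algebra; splitting the weight along in_features and summing the per-shard partial products reproduces the unsharded result. A single-process check:

import torch

x = torch.randn(2, 8)
w = torch.randn(4, 8)                      # full weight, out_features x in_features

x0, x1 = x.chunk(2, dim=-1)                # Input: S[1]
w0, w1 = w.chunk(2, dim=-1)                # Weight split along in_features
partial_sum = x0 @ w0.t() + x1 @ w1.t()    # what the per-rank partial outputs add up to

assert torch.allclose(partial_sum, x @ w.t(), atol=1e-5)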
@@ -1,8 +1,6 @@
import torch
import itertools
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.gemini.chunk import TensorState, Chunk
@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
        >>> from colossalai.core import global_context as gpc
        >>> from colossalai.context import ParallelMode
        >>> model = torch.nn.Linear(20, 1)
        >>> model = ColoDDP(model)
        >>> // model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
        >>> pg = ProcessGroup(tp_degree = world_size//2)
        >>> model = ColoDDP(model, pg)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)
@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
        module (torch.nn.Module): Module to apply DDP.
        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
            If it's None, the default data parallel group will be used. Defaults to None.
        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
            If it's None, the default CPU data parallel group will be used. Defaults to None.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 process_group: Optional[dist.ProcessGroup] = None,
                 process_group: ColoProcessGroup,
                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                 bucket_cap_mb: int = 25,
                 rebuild_bucket: bool = True) -> None:
@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
        super().__init__()
        self.module = module
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
        assert process_group

        self.process_group = process_group.dp_process_group()
        self.dp_world_size = self.process_group.size()
        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket
@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
            return empty_grad

        else:
            #TODO(jiaruifang) fixme
            raise NotImplementedError
            dist.all_reduce(grad, group=self.cpu_process_group)
            return grad

@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
    For more details, see the API reference of ``GeminiManager``.
    """

    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
        super().__init__(module.half())
    def __init__(self,
                 module: torch.nn.Module,
                 gemini_manager: GeminiManager,
                 process_group: Optional[ColoProcessGroup] = None) -> None:
        super().__init__(module.half(), process_group=process_group)
        self.gemini_manager = gemini_manager
        self.chunk_manager = gemini_manager.chunk_manager
        self.param_op_hook = ZeROHookV2(gemini_manager)
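For orientation (not part of the diff), constructing ColoDDP after this change looks roughly as follows. The import path and group sizing are assumptions, and a launched distributed environment is required:

import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel import ColoDDP    # assumed import path

pg = ProcessGroup(tp_degree=1)                # degree value is illustrative
model = ColoDDP(torch.nn.Linear(20, 1), pg)   # process_group is now a required ColoProcessGroup

logits = model(torch.randn(8, 20))
loss = logits.sum()
model.backward(loss)                          # gradients are reduced over pg.dp_process_group()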
@@ -52,5 +52,5 @@ class ColoModule(object):
    def get_param_names(self):
        return self._shard_params

    def register(self, compute_pattern):
    def register(self, compute_pattern, pg):
        raise NotImplementedError
@@ -1,5 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode

@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
        super(ColoEmbedding, self).__init__()
        self._register_shard_params(['weight'])

    def register(self, compute_pattern):
    def register(self, compute_pattern, pg: ProcessGroup):
        if not compute_pattern in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D()
                self._set_TP1D(pg)

    def _set_TP1D(self):
    def _set_TP1D(self, pg: ProcessGroup):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
            },
            mode='row',
        )
@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
            },
            mode='col',
        )
@@ -1,7 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ComputePattern, distspec, ProcessGroup


class ColoLinear(ColoModule):
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
        super(ColoLinear, self).__init__()
        self._register_shard_params(['weight', 'bias'])

    def register(self, compute_pattern):
    def register(self, compute_pattern, pg: ProcessGroup):
        if not compute_pattern in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D()
                self._set_TP1D(pg)

    def _set_TP1D(self):
    def _set_TP1D(self, pg):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'bias':
                    None
                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
                'bias': None
            },
            mode='row',
        )
@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'bias':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
            },
            mode='col',
        )
@@ -1,5 +1,5 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
from . import ColoModule
import torch

@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
    return None


def check_colo_module(module: torch.nn.Module, recursive=True):
def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
    if is_colo_module(module):
        colo_module = get_colo_module(module)
        param_names = colo_module.get_param_names()
@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                continue

            if compute_pattern is not None:
                colo_module.register(compute_pattern)
                colo_module.register(compute_pattern, pg)
                if not colo_module.has_compute_pattern(compute_pattern):
                    raise Exception(
                        f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
    if recursive == True:
        for submodule in module.children():
            check_colo_module(submodule, recursive=True)
            check_colo_module(submodule, pg=pg, recursive=True)


def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
def init_colo_module(module: torch.nn.Module,
                     compute_spec: ComputeSpec,
                     pg: ProcessGroup,
                     recursive=True,
                     mode='default'):
    compute_pattern = compute_spec.compute_pattern
    if is_colo_module(module):
        # for each param
        # set DistSpec and ComputeSpec
        colo_module = get_colo_module(module)
        colo_module.register(compute_pattern)
        colo_module.register(compute_pattern, pg)
        if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
            raise NotImplementedError
        # a set for modules which update at least one param in the init process.
@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True,
        for mod in param.shared_param_modules:
            modules_update_param.add(mod)
    for mod in modules_update_param:
        check_colo_module(mod, recursive=False)
        check_colo_module(mod, pg, recursive=False)
    if recursive == True:
        for submodule in module.children():
            init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
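Finally, a short sketch of driving the updated init_colo_module entry point (illustrative, not from the commit); `model` stands for any module containing registered Colo modules, and the tp degree is made up:

from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup

pg = ProcessGroup(tp_degree=4)                 # illustrative 4-way tensor parallelism
spec = ComputeSpec(ComputePattern.TP1D)

# the ProcessGroup is now passed explicitly instead of being resolved through gpc
init_colo_module(model, spec, pg=pg, recursive=True, mode='col')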