[refactor] remove gpc dependency in colotensor's _ops (#1189)

Jiarui Fang 2022-07-04 18:54:37 +08:00 committed by GitHub
parent abf6a262dc
commit 060b917daf
33 changed files with 499 additions and 357 deletions

View File

@@ -1,6 +1,12 @@
import torch
from typing import Union, Optional
from colossalai.tensor import ColoTensor
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup

GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]

@@ -10,3 +16,182 @@ def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
    if tensor is not None and not isinstance(tensor, ColoTensor):
        tensor = ColoTensor.from_torch_tensor(tensor)
    return tensor


def set_parallel_input(input_parallel: bool):
    env.parallel_input_1d = input_parallel


def get_parallel_input():
    return env.parallel_input_1d


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)


def _reduce(input_, pg: ProcessGroup):
    # skip if only one rank involved
    if pg.tp_world_size() == 1:
        return input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    dist.all_reduce(input_, group=group)
    return input_


def _split(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # Split along last dimension.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = pg.tp_local_rank()
    output = tensor_list[rank].contiguous()
    return output


def _gather(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # all gather
    rank = pg.tp_local_rank()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()
    return output


class _ReduceGrad(torch.autograd.Function):
    """
    Pass the input through in the forward pass and all-reduce the gradient in the backward pass.

    Args:
        input_: input matrix.
        process_group: tensor parallel process group.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_, process_group):
        ctx.mode = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output, ctx.mode), None


class _ReduceInput(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        process_group: tensor parallel process group.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_, process_group):
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to this rank.

    Args:
        input_: input matrix.
        process_group: tensor parallel process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _split(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.mode, ctx.dim), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix.
        process_group: tensor parallel process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _gather(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.mode, ctx.dim), None, None


def reduce_grad(input_, process_group):
    return _ReduceGrad.apply(input_, process_group)


def reduce_input(input_, process_group):
    return _ReduceInput.apply(input_, process_group)


def split_forward_gather_backward(input_, process_group, dim):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim)


def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)

View File

@@ -1,10 +1,9 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -12,18 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res
    mat1 = mat1.convert_to_dist_spec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]))

    # Output:P
    partial_output = torch.mm(mat1, mat2)
    # Reduce(Output)
    output = reduce_input(partial_output, mat1.get_process_group())
    # input
    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor + alpha * output
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.get_process_group())))
    return output


@@ -31,13 +28,12 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    compute_spec = mat2.tensor_spec.compute_spec
    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.get_process_group()))
    mat1 = reduce_grad(mat1, mat1.get_process_group())

    output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
    output_spec = TensorSpec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if compute_spec.output_replicate:
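
The 1Drow kernel relies on the identity that sharding mat1 along columns and mat2 along rows, multiplying shard by shard, and all-reducing the partial products reproduces the full mat1 @ mat2. A small single-process sketch of that identity (illustrative shapes only, no ColossalAI required):

import torch

world_size = 2
mat1 = torch.randn(4, 6)    # replicated activation, sharded along columns: S[1]
mat2 = torch.randn(6, 8)    # weight sharded along rows: S[0]
partials = [a @ b for a, b in zip(torch.chunk(mat1, world_size, dim=1),
                                  torch.chunk(mat2, world_size, dim=0))]
# summing the per-rank partial outputs (the all-reduce) equals the full matmul
assert torch.allclose(sum(partials), mat1 @ mat2, atol=1e-5)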

View File

@@ -1,11 +1,8 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


def colo_embedding_1Dcol(input_tensor: ColoTensor,
@@ -17,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) to (num_embeddings, embedding_dim/P)
    # Gather the split lookup table
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

    output_parallel = F.embedding(input_tensor,
                                  weight,
@@ -26,9 +23,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                  norm_type=norm_type,
                                  scale_grad_by_freq=scale_grad_by_freq,
                                  sparse=sparse)
    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    compute_spec = weight.tensor_spec.compute_spec
@@ -49,9 +45,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # embedding_1Drow splits the weight (lookup table) to (num_embeddings/P, embedding_dim)
    # Find indices in this shard and mask those not here
    # Reduce all
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    tensor_parallel_rank = weight.tensor_spec.dist_spec.process_group.tp_local_rank()
    num_embeddings_per_partition = weight.size_local(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -75,9 +72,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # Mask the output embedding.
    partial_output[input_mask, :] = 0.
    # Reduce across all the model parallel GPUs.
    output = reduce_input(partial_output, weight.get_process_group())
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
    return output
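
colo_embedding_1Drow masks indices that fall outside the local vocabulary shard, looks the rest up locally, zeroes the masked rows, and relies on the final all-reduce to assemble the result. A single-process sketch of that masking trick with a hypothetical 2-way row shard (not part of the commit):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
full = torch.randn(6, 4)                    # full (num_embeddings=6, embedding_dim=4) table
shards = torch.chunk(full, 2, dim=0)        # each simulated rank holds 3 rows
idx = torch.tensor([0, 2, 4, 5])
partials = []
for rank, shard in enumerate(shards):
    start, end = rank * 3, (rank + 1) * 3
    mask = (idx < start) | (idx >= end)     # indices owned by other shards
    local = idx.clone() - start
    local[mask] = 0
    out = F.embedding(local, shard)
    out[mask, :] = 0.                       # mask rows looked up on the wrong shard
    partials.append(out)
# the all-reduce of the masked partial outputs recovers the full lookup
assert torch.allclose(sum(partials), F.embedding(idx, full))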

View File

@@ -32,9 +32,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                      per_sample_weights=per_sample_weights,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx)
    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if weight.tensor_spec.compute_spec.output_replicate:

View File

@@ -17,7 +17,7 @@ def colo_layernorm(
    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

    # TODO (ver217): check dist spec
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.get_process_group()))

    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
    output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)

View File

@@ -2,9 +2,8 @@ import torch.nn.functional as F
from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


@@ -13,19 +12,18 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # All-Reduce(Output) + bias = res
    # Input:S[1]
    input_tensor = input_tensor.convert_to_dist_spec(
        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]))

    # Output:P
    partial_output = F.linear(input_tensor, weight)
    # Reduce(Output)
    output = reduce_input(partial_output, weight.get_process_group())
    # Bias
    if bias is not None:
        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
        output = output + bias

    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
    return output


@@ -35,13 +33,13 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:B
    compute_spec = weight.tensor_spec.compute_spec
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_parallel = reduce_grad(input_tensor, weight.tensor_spec.dist_spec.process_group)

    output_parallel = F.linear(input_parallel, weight, bias)
    output = ColoTensor.from_torch_tensor(output_parallel,
                                          spec=TensorSpec(
                                              distspec.shard(weight.get_process_group(), [-1],
                                                             [weight.get_tp_world_size()]),
                                              ComputeSpec(ComputePattern.TP1D)))
    if compute_spec.output_replicate:
        return output.to_replicate()

View File

@@ -1,8 +1,6 @@
import torch
import itertools
import torch.distributed as dist
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.gemini.chunk import TensorState, Chunk
@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
        >>> from colossalai.core import global_context as gpc
        >>> from colossalai.context import ParallelMode
        >>> model = torch.nn.Linear(20, 1)
        >>> pg = ProcessGroup(tp_degree=world_size // 2)
        >>> model = ColoDDP(model, pg)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)
@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
        module (torch.nn.Module): Module to apply DDP.
        process_group (ColoProcessGroup): The process group which DDP uses.
        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
            If it's None, the default CPU data parallel group will be used. Defaults to None.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 process_group: ColoProcessGroup,
                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                 bucket_cap_mb: int = 25,
                 rebuild_bucket: bool = True) -> None:
@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
        super().__init__()
        self.module = module
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        assert process_group

        self.process_group = process_group.dp_process_group()
        self.dp_world_size = self.process_group.size()
        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket
@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
            return empty_grad
        else:
            # TODO(jiaruifang) fixme
            raise NotImplementedError
            dist.all_reduce(grad, group=self.cpu_process_group)
            return grad
@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
        For more details, see the API reference of ``GeminiManager``.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 gemini_manager: GeminiManager,
                 process_group: Optional[ColoProcessGroup] = None) -> None:
        super().__init__(module.half(), process_group=process_group)
        self.gemini_manager = gemini_manager
        self.chunk_manager = gemini_manager.chunk_manager
        self.param_op_hook = ZeROHookV2(gemini_manager)
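
With the constructor change above, callers are expected to build a ColoProcessGroup first and hand it to the DDP wrapper explicitly. A minimal sketch of the intended call order, mirroring the updated tests (the helper name is ours, and an already-initialized torch.distributed environment is assumed):

import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel import ColoDDP

def wrap_with_colo_ddp(module: torch.nn.Module) -> ColoDDP:
    pg = ProcessGroup()                      # default degrees: plain data parallelism (assumption)
    return ColoDDP(module, process_group=pg)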

View File

@@ -52,5 +52,5 @@ class ColoModule(object):
    def get_param_names(self):
        return self._shard_params

    def register(self, compute_pattern, pg):
        raise NotImplementedError

View File

@@ -1,5 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
        super(ColoEmbedding, self).__init__()
        self._register_shard_params(['weight'])

    def register(self, compute_pattern, pg: ProcessGroup):
        if not compute_pattern in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D(pg)

    def _set_TP1D(self, pg: ProcessGroup):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
            },
            mode='row',
        )
@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
            },
            mode='col',
        )

View File

@@ -1,7 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup


class ColoLinear(ColoModule):
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
        super(ColoLinear, self).__init__()
        self._register_shard_params(['weight', 'bias'])

    def register(self, compute_pattern, pg: ProcessGroup):
        if not compute_pattern in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D(pg)

    def _set_TP1D(self, pg):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
                'bias': None
            },
            mode='row',
        )
@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
            },
            mode='col',
        )

View File

@@ -1,5 +1,5 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
from . import ColoModule
import torch
@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
    return None


def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
    if is_colo_module(module):
        colo_module = get_colo_module(module)
        param_names = colo_module.get_param_names()
@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
                continue
        if compute_pattern is not None:
            colo_module.register(compute_pattern, pg)
            if not colo_module.has_compute_pattern(compute_pattern):
                raise Exception(
                    f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
                    raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
    if recursive == True:
        for submodule in module.children():
            check_colo_module(submodule, pg=pg, recursive=True)


def init_colo_module(module: torch.nn.Module,
                     compute_spec: ComputeSpec,
                     pg: ProcessGroup,
                     recursive=True,
                     mode='default'):
    compute_pattern = compute_spec.compute_pattern
    if is_colo_module(module):
        # for each param
        # set DistSpec and ComputeSpec
        colo_module = get_colo_module(module)
        colo_module.register(compute_pattern, pg)
        if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
            raise NotImplementedError
        # a set for modules which update at least one param in the init process.
@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode='default'):
            for mod in param.shared_param_modules:
                modules_update_param.add(mod)
        for mod in modules_update_param:
            check_colo_module(mod, pg, recursive=False)
    if recursive == True:
        for submodule in module.children():
            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
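
init_colo_module now threads the same ProcessGroup down to every registered ColoModule. A hedged usage sketch, under the assumptions that the model is built inside ColoInitContext and that init_colo_module is importable from colossalai.nn.parallel.layers (the import path is not shown in this diff):

import torch
from colossalai.tensor import ComputeSpec, ComputePattern, ProcessGroup
from colossalai.nn.parallel.layers import init_colo_module          # assumed import path
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=torch.device('cuda')):
    model = torch.nn.Linear(16, 4)                                   # toy module for illustration
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
init_colo_module(model, ComputeSpec(ComputePattern.TP1D), pg=pg, recursive=True, mode='col')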

View File

@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL

    def get_process_group(self) -> 'ProcessGroup':
        return self._tensor_spec.dist_spec.process_group

    def get_tp_world_size(self) -> int:
        return self._tensor_spec.dist_spec.process_group.tp_world_size()

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:

View File

@@ -5,6 +5,7 @@ from contextlib import contextmanager
import torch
import torch.distributed as dist
from packaging import version
from colossalai.logging import get_dist_logger

# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -64,7 +65,7 @@ class DistSpecManager:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)

        chunk = tensor
        idx = dist_spec.process_group.tp_local_rank()
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]
@@ -91,8 +92,9 @@ class DistSpecManager:
        saved_dev = tensor.device
        tensor.data = tensor.data.cuda()

        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
        assert tensor.device.type == 'cuda'
        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            new_buffer = []
            dim = old_dist_spec.dims[i]
@@ -108,14 +110,14 @@ class DistSpecManager:
    @staticmethod
    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        world_size = old_dist_spec.process_group.tp_world_size()
        if world_size == 1:
            return tensor

        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
            f"collective function, however, we got {tensor.device.type} device and " \
            f"{old_dist_spec.process_group.backend} backend"

        gather_dim = old_dist_spec.dims[0]
        scatter_dim = dist_spec.dims[0]
@@ -126,7 +128,7 @@ class DistSpecManager:
        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())

        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size

View File

@@ -1,5 +1,5 @@
from enum import Enum
from colossalai.tensor import ProcessGroup
from typing import Optional, List
from numpy import prod

@@ -51,8 +51,8 @@ def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:


def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
    assert process_group is not None and isinstance(process_group, ProcessGroup)
    assert isinstance(dims, list) and isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)
    assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
    return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
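
shard() now validates its arguments against the tensor-parallel world size of the ColoProcessGroup, so prod(num_partitions) must equal pg.tp_world_size(). A hedged sketch of building a 1D column-shard spec the way the updated tests do (assumes torch.distributed has already been initialized via colossalai.launch):

import torch
from colossalai.tensor import ProcessGroup, TensorSpec, ComputePattern, ComputeSpec, distspec

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
# one partition along the last dim per TP rank, so prod(num_partitions) == pg.tp_world_size()
col_spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))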

View File

@@ -1,5 +1,6 @@
import torch
from typing import List, Optional
from colossalai.logging import get_dist_logger


class ProcessGroup:
@@ -41,12 +42,12 @@ class ProcessGroup:
        if dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, \
                f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._tp_degree = self._world_size // dp_degree

        if not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, \
                f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._dp_degree = self._world_size // tp_degree

        self._tp_rank_list = []
        self._dp_rank_list = []
@@ -58,12 +59,48 @@ class ProcessGroup:
            if rank_id // self._tp_degree == self._rank // self._tp_degree:
                self._tp_rank_list.append(rank_id)

        assert backend == 'nccl'
        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)

        self.logger = get_dist_logger('ProcessGroup')
        self.logger.info(f'{self._rank} initialized TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')

    @property
    def backend(self):
        return self._backend

    def __eq__(self, obj: 'ProcessGroup') -> bool:
        if not isinstance(obj, ProcessGroup):
            return False
        if self._rank != obj._rank:
            assert False
        if self._rank_list != obj._rank_list:
            assert False
        if self._tp_rank_list != obj._tp_rank_list:
            assert False
        if self._dp_rank_list != obj._dp_rank_list:
            assert False
        if self._backend != obj._backend:
            assert False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

    def rank(self):
        return self._rank

    def world_size(self):
        return self._world_size

    def tp_local_rank(self):
        return self._rank % self._tp_degree

    def dp_local_rank(self):
        return self._rank // self._tp_degree

    def dp_world_size(self):
        return len(self._dp_rank_list)
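
The degree bookkeeping above fixes the rank layout: ranks are grouped into contiguous TP groups of size tp_degree, and DP groups stride across them, which is why tp_local_rank is rank % tp_degree and dp_local_rank is rank // tp_degree. A pure-Python sketch of that arithmetic (the convention is inferred from this diff; world_size and tp_degree are example values):

world_size, tp_degree = 8, 4
dp_degree = world_size // tp_degree
for rank in range(world_size):
    tp_local_rank = rank % tp_degree
    dp_local_rank = rank // tp_degree
    # contiguous TP group containing this rank, and the DP group striding across TP groups
    tp_group = [r for r in range(world_size) if r // tp_degree == rank // tp_degree]
    dp_group = [r for r in range(world_size) if r % tp_degree == rank % tp_degree]
    assert tp_group[tp_local_rank] == rank and dp_group[dp_local_rank] == rank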

View File

@@ -17,11 +17,12 @@ class TensorSpec(object):
        self.compute_spec = compute_spec
        self.dist_spec = dist_spec

    # TODO(jiaruifang) actually need tp process group
    def get_process_group(self):
        return self.dist_spec.process_group

    def get_process_group_size(self):
        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

    def get_placement(self):
        return self.dist_spec.placement
@@ -30,7 +31,7 @@ class TensorSpec(object):
        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
            or (len(self.dist_spec.num_partitions) == 1
                and self.dist_spec.num_partitions[0] == 1) \
            or (self.dist_spec.process_group.tp_world_size() == 1)

    def is_shard_1dcol(self):
        return self.dist_spec.placement == DistPlacementPattern.SHARD \

View File

@@ -15,6 +15,7 @@ import torch.distributed as dist
import os
import random
import numpy as np
from colossalai.tensor import ProcessGroup


def set_seed(seed):
@@ -27,14 +28,16 @@ def set_seed(seed):


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)


def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    pg = ProcessGroup()
    return ZeroDDP(module, gemini_manager, pg)


class Net(torch.nn.Module):

View File

@@ -13,6 +13,7 @@ from colossalai.nn.parallel import ZeroDDP, ColoDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
from collections import OrderedDict
from colossalai.tensor import ProcessGroup


def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@@ -22,14 +23,16 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)


def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    pg = ProcessGroup()
    return ZeroDDP(module, gemini_manager, process_group=pg)


def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):

View File

@@ -41,7 +41,7 @@ def tensor_equal(A, B):
    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)
@@ -50,8 +50,10 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
        if dims_not_eq.numel() == 1:
            # 1D shard
            dim = dims_not_eq.item()
            if world_size is None:
                world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            if rank is None:
                rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
        else:
            raise NotImplementedError

View File

@@ -3,14 +3,12 @@ import torch
import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import distspec
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from _utils import tensor_shard_equal, tensor_equal


@@ -38,18 +36,14 @@ class Conv1D(nn.Module):
        return x


def init_1d_row(weight, bias, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col(weight, bias, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)
        bias.set_tensor_spec(spec)
@@ -59,7 +53,9 @@ def run_with_spec(spec_init_func):
    model = Conv1D(4, 16).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    spec_init_func(weight, bias, pg)
    x = torch.rand(2, 16).cuda()
    out = model(x)
    colo_out = torch.addmm(bias, x, weight)
@@ -68,13 +64,12 @@ def run_with_spec(spec_init_func):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_row)
    run_with_spec(init_1d_col)

View File

@@ -7,12 +7,12 @@ import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
from functools import partial


def run():
    group = ProcessGroup(tp_degree=dist.get_world_size())
    rank = dist.get_rank()
    size = dist.get_world_size()
    depth = int(math.sqrt(size))
@@ -34,7 +34,7 @@ def run():


def check_mem():
    group = ProcessGroup(tp_degree=dist.get_world_size())
    size = dist.get_world_size()
    assert torch.cuda.memory_allocated() == 0
    x = torch.rand(32, 32).cuda()

View File

@@ -1,6 +1,5 @@
import torch
from colossalai.tensor import distspec, ColoParameter
from torch.nn import functional as F
from functools import partial
@@ -10,23 +9,21 @@ import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_col(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def run_with_spec(spec_init_func):
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    model = torch.nn.EmbeddingBag(10, 4).cuda()
    weight = ColoParameter(model.weight.clone())
    spec_init_func(weight, pg)
    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
    offsets = torch.tensor([0, 4]).cuda()
    out = model(inputs, offsets=offsets)
@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):

View File

@@ -1,5 +1,4 @@
import torch
from colossalai.tensor import ColoTensor, distspec
from torch.nn import functional as F
from functools import partial
@@ -11,30 +10,26 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_row(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def run_with_spec(spec_init_func, pg: ProcessGroup):
    model = torch.nn.Embedding(12, 32).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
    spec_init_func(weight, pg)
    x = torch.tensor((0, 3, 6, 9)).cuda()
    out = model(x)
    colo_out = F.embedding(x, weight)
@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func, pg: ProcessGroup):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    # compare grad inside a TP group
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    run_with_spec(init_1d_row, pg)
    run_with_spec(init_1d_col, pg)


@pytest.mark.dist

View File

@@ -1,51 +1,54 @@
 import pytest
 import colossalai
-from colossalai.context.parallel_mode import ParallelMode
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
-from colossalai.core import global_context as gpc
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
 from functools import partial
 from _utils import tensor_equal, tensor_shard_equal, set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
+import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode

-def init_1d_row_spec(model):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row_spec(model, pg: ProcessGroup):
+    tensor_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_tensor_spec(spec)
+                p.set_tensor_spec(tensor_spec)

-def init_1d_col_spec(model):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col_spec(model, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
                 p.set_tensor_spec(spec)

-def check_param_equal(model, torch_model):
+def check_param_equal(model, torch_model, pg: ProcessGroup):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert tensor_shard_equal(torch_p, p)
+        assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
+        assert pg.tp_world_size() is not None
+        assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())

-def check_grad_equal(model, torch_model):
+def check_grad_equal(model, torch_model, pg: ProcessGroup):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
-        assert tensor_shard_equal(torch_p.grad, p.grad)
+        assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())

 def run_gpt(init_spec_func, use_ddp):
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp):
     model = model.cuda()
     torch_model = model_builder().cuda()
     if use_ddp:
-        model = ColoDDP(model)
+        # torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
+        # torch.distributed.barrier()
         torch_model = DDP(torch_model,
                           device_ids=[gpc.get_global_rank()],
                           process_group=gpc.get_group(ParallelMode.DATA))
+        model = ColoDDP(model, process_group=pg)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)
-    init_spec_func(model)
-    check_param_equal(model, torch_model)
+    init_spec_func(model, pg)
+    check_param_equal(model, torch_model, pg)
     model.train()
     torch_model.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    set_seed(pg.tp_local_rank())
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         logits = model(input_ids, attn_mask)
         torch_logits = torch_model(input_ids, attn_mask)
-        assert tensor_equal(torch_logits, logits)
+        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
         loss = criterion(logits, input_ids)
         torch_loss = criterion(torch_logits, input_ids)
         if use_ddp:
@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp):
         else:
             loss.backward()
         torch_loss.backward()
-        check_grad_equal(model, torch_model)
+        check_grad_equal(model, torch_model, pg)
         if i > 0:
             break

@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp):
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_gpt(init_1d_row_spec, use_ddp)
+    # run_gpt(init_1d_row_spec, use_ddp)
     run_gpt(init_1d_col_spec, use_ddp)

 @pytest.mark.dist
+@pytest.mark.skip("under development")
 @pytest.mark.parametrize('world_size', [1, 4])
 @pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
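The run_gpt changes above replace gpc's DATA group with a ProcessGroup that carries the DP/TP split explicitly, and ColoDDP now receives that group. A hedged sketch of the setup, mirroring the calls in the hunk (assumes colossalai is installed and the code runs in a process started via colossalai.launch; the exact DP/TP factorization follows ProcessGroup's own semantics):

import torch
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ProcessGroup

def wrap_colo_ddp(model: torch.nn.Module, use_ddp: bool) -> torch.nn.Module:
    world_size = torch.distributed.get_world_size()
    # dp_degree=2 asks for two data-parallel replicas; the remaining factor of the
    # world size is used as the tensor-parallel degree of this group
    pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
    return ColoDDP(model, process_group=pg) if use_ddp else model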
View File
@@ -1,88 +0,0 @@
-from colossalai.utils import free_port, get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ComputePattern, ComputeSpec
-from functools import partial
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
-from colossalai.nn.parallel.layers import init_colo_module
-from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.nn.optimizer import ColoOptimizer
-
-import colossalai
-import torch
-import torch.multiprocessing as mp
-import pytest
-
-
-class Net(torch.nn.Module):
-
-    def __init__(self):
-        super(Net, self).__init__()
-        self.embed = torch.nn.Embedding(20, 4)
-        self.proj = torch.nn.Linear(4, 8)
-
-    def forward(self, x):
-        # move input to cpu and restore output
-        current_dev = x.device
-        x = x.to('cpu')
-        x = self.embed(x)
-        x = x.to(current_dev)
-        x = self.proj(x)
-        return x
-
-
-def run_hybrid_device(use_ddp, mode):
-    with ColoInitContext(device=get_current_device()):
-        model = Net()
-
-    real_model = model
-    if use_ddp:
-        model = ColoDDP(model)
-        real_model = model.module
-
-    print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
-    #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
-    parallel_action = ComputeSpec(ComputePattern.TP1D)
-    init_colo_module(model, parallel_action, recursive=True, mode=mode)
-
-    # use cpu gloo to handle embedding
-    real_model.embed.to('cpu')
-    gloo_group_tp = gpc.get_cpu_group(ParallelMode.PARALLEL_1D)
-    real_model.embed.weight.spec.dist_spec.process_group = gloo_group_tp
-
-    print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
-    #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
-
-    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
-    data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
-    out = model(data)
-    out.sum().backward()
-    optimizer.step()
-
-
-def run_dist(rank, world_size, port, use_ddp, mode):
-    if use_ddp and world_size == 1:
-        return
-    tp_world_size = world_size // 2 if use_ddp else world_size
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_hybrid_device(use_ddp, mode)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.parametrize('use_ddp', [False, True])
-@pytest.mark.parametrize('mode', ['col', 'row'])
-@rerun_if_address_is_in_use()
-# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
-def _test_hybrid_device(world_size, use_ddp, mode):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    _test_hybrid_device(4, True, 'row')
View File
@@ -12,32 +12,29 @@ import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 from _utils import tensor_equal, tensor_shard_equal

-def init_1d_row(weight, bias):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row(weight, bias, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)

-def init_1d_col(weight, bias):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col(weight, bias, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
         bias.set_tensor_spec(spec)

 def run_with_spec(spec_init_func):
+    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
     model = torch.nn.Linear(4, 8).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
-    spec_init_func(weight, bias)
+    spec_init_func(weight, bias, pg)
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func):
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, weight.grad)
-    assert tensor_shard_equal(model.bias.grad, bias.grad)
+    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())

 def run_dist(rank, world_size, port):
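tensor_shard_equal now receives the caller's TP rank and world size instead of reading them from gpc. The repo's _utils implementation is not shown in this diff; the following is only a plausible, hypothetical stand-in with the same signature, for illustration of what the extra arguments are used for:

import torch

def tensor_shard_equal(full: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int) -> bool:
    # hypothetical helper: the real _utils version in the repo may differ
    if full.shape == shard.shape:          # parameter was not sharded at all
        return torch.allclose(full, shard)
    for dim, size in enumerate(full.shape):
        if size == shard.shape[dim] * world_size:
            # compare the rank-th chunk of the full tensor along the sharded dim
            return torch.allclose(full.chunk(world_size, dim=dim)[rank], shard)
    raise ValueError(f"cannot match shard {tuple(shard.shape)} against {tuple(full.shape)}")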
View File
@@ -1,10 +1,12 @@
-from colossalai.tensor.colo_parameter import ColoParameter
-from tests.components_to_test.registry import non_distributed_component_funcs
-import colossalai
 import pytest
+from functools import partial
+from _utils import tensor_shard_equal, set_seed
 import torch
 import torch.multiprocessing as mp
+from colossalai.tensor.colo_parameter import ColoParameter
+import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
 from colossalai.nn.optimizer import ColoOptimizer
-from functools import partial
-from _utils import tensor_shard_equal, set_seed
+from tests.components_to_test.registry import non_distributed_component_funcs

 def init_1d_row_linear(weight, pg: ProcessGroup):
-    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
-                      ComputeSpec(ComputePattern.TP1D))
+    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)

 def init_1d_col_linear(weight, pg):
-    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
-                      ComputeSpec(ComputePattern.TP1D))
+    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)

 def init_1d_row_embedding(weight, pg):
-    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
-                      ComputeSpec(ComputePattern.TP1D))
+    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)

 def init_1d_col_embedding(weight, pg):
-    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
-                      ComputeSpec(ComputePattern.TP1D))
+    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name):
         with torch.no_grad():
             # check param
             for p, torch_p in zip(model.parameters(), model_torch.parameters()):
-                assert tensor_shard_equal(torch_p, p)
+                assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
         if i > 5:
             break
View File
@@ -13,12 +13,10 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor import distspec
+from colossalai.tensor import distspec, ProcessGroup
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 def run_model_with_spec(mode, model_name):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    rank = pg.rank()
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name):
         for p1, p2 in zip(model.parameters(), model_seq.parameters()):
             p2.data.copy_(p1.data)

-    parallel_action = ComputeSpec(ComputePattern.TP1D)
+    compute_spec = ComputeSpec(ComputePattern.TP1D)
     # Not all layers in Bert can be mod by 4.
     # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
     if 'bert' == model_name:
         if 'col' == mode:
-            init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode=mode)
-            init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
-            init_colo_module(model.classifier, parallel_action, recursive=True, mode='row')
+            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
+            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
+            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
         elif 'row' == mode:
-            init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode='col')
-            init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
-            init_colo_module(model.classifier, parallel_action, recursive=True, mode=mode)
+            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
+            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
+            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
     elif 'simple_net' == model_name:
-        init_colo_module(model, parallel_action, recursive=True, mode=mode)
+        init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
     model = model.cuda()

     for i, (data, label) in enumerate(train_dataloader):
         data = data.to(get_current_device())
         label = label.to(get_current_device())
-        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
         if criterion:
             output = model(data)
@@ -113,9 +113,10 @@ def run_linear_with_spec(mode):
     model = torch.nn.Linear(4, 8)
     model_handy = copy(model)
-    parallel_action = ComputeSpec(ComputePattern.TP1D)
-    init_colo_module(model, parallel_action, recursive=True, mode=mode)
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    compute_spec = ComputeSpec(ComputePattern.TP1D)
+    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

     x = torch.rand(2, 4).cuda()
     out = model(x)
@@ -124,8 +125,8 @@ def run_linear_with_spec(mode):
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
-    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
+    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())

 def run_check_shared_param():
@@ -136,6 +137,10 @@ def run_check_shared_param():
     num_layer = 2
     vocab_size = 24

+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    rank = pg.rank()
     config = BertConfig(vocab_size=vocab_size,
                         hidden_size=hidden_dim,
                         intermediate_size=hidden_dim * 4,
@@ -148,18 +153,16 @@ def run_check_shared_param():
     model = BertForMaskedLM(config)
     model = model.cuda()
-    parallel_action = ComputeSpec(ComputePattern.TP1D)
+    compute_spec = ComputeSpec(ComputePattern.TP1D)
     # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
     assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
     # They are all Linear, so both row is allowed. This should pass check.
-    init_colo_module(model, parallel_action, recursive=True, mode='row')
+    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+    col_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     model.cls.predictions.bias.set_tensor_spec(col_spec)
     try:
-        check_colo_module(model.cls.predictions.decoder, recursive=False)
+        check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
     except Exception as e:
         assert 'incorrectly sharded' in str(e)
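After this change init_colo_module receives the ProcessGroup explicitly via pg= instead of resolving PARALLEL_1D through gpc. A short usage sketch mirroring the calls above (assumes a launched distributed context):

import torch
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
compute_spec = ComputeSpec(ComputePattern.TP1D)
model = torch.nn.Linear(4, 8)
# shard the module's parameters row-wise across the TP ranks of `pg`
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
model = model.cuda()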
View File
@@ -4,10 +4,9 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
-from torch.distributed.distributed_c10d import _get_default_group
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import distspec, TensorSpec
@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other):
 def check_element_wise_ops():
-    pg = _get_default_group()
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.size()])))
+    x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()])))
     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
     check_spec_eq(x, torch.abs(x))
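The element-wise check now builds its ProcessGroup directly instead of borrowing torch's default group, and sizes the shard with pg.tp_world_size(). A condensed sketch of the same flow (assumes a launched distributed context):

import torch
from colossalai.tensor import ColoTensor, ProcessGroup, TensorSpec, distspec

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
t = torch.rand(2, 2)
x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()])))
y = torch.abs(x)    # element-wise ops are expected to keep the spec of `x`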
View File
@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
-from colossalai.context import ParallelMode
 from functools import partial
@@ -55,11 +54,9 @@ def test_operand():
 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
     rank = gpc.get_global_rank()
-    pg = ProcessGroup(rank, list(range(world_size)))
-    assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref,
-        TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
+        t_ref, TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])))
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     rank = gpc.get_global_rank()
-    pg = ProcessGroup(rank, list(range(world_size)))
-    shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+    shard_spec = distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
-    assert t.shape == torch.Size((4 * world_size, 5))
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"

 def _run_tensor_replicated_init(world_size):
@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"

+def _run_process_group(world_size):
+    pg1 = ProcessGroup()
+    pg2 = ProcessGroup()
+    assert pg1 == pg2

 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
     _run_tensor_replicated_init(world_size)
     _run_view(world_size)
+    _run_process_group(world_size)

 @pytest.mark.dist
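The _run_tensor_shard_init hunk above exercises a shard-then-replicate round trip with the new ProcessGroup-based spec. A compact restatement of that check, assuming the same API and a launched context; torch.distributed.get_rank() stands in here for gpc.get_global_rank():

import torch
from colossalai.tensor import ColoTensor, ProcessGroup, TensorSpec, distspec

world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
local = torch.randn(4, 5)    # every rank contributes a (4, 5) shard along dim 0
t = ColoTensor.from_torch_tensor(
    local, TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])))
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))    # gather back to the full tensor
assert t.shape == torch.Size((4 * world_size, 5))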
View File
@@ -2,13 +2,11 @@ import pytest
 import colossalai
 import torch
 import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.gemini import ChunkManager
-from colossalai.core import global_context as gpc
 from functools import partial
 from _utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup

-def check_param_equal(model, torch_model):
+def check_param_equal(model, torch_model, pg: ProcessGroup):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         if p.storage().size() > 0:
             assert p.dtype == torch.half
-            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
+            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
+                                      pg.tp_world_size()), f'{torch_p} vs {p}'

-def check_grad_equal(model, torch_model):
+def check_grad_equal(model, torch_model, pg: ProcessGroup):
     for p, torch_p in zip(model.parameters(), torch_model.parameters()):
         if p.grad is not None:
-            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
+            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
+                                      pg.tp_local_rank(), pg.tp_world_size())

 def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
     return logits

-def init_1d_row_spec(model):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row_spec(model, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
                 p.set_tensor_spec(spec)

-def init_1d_col_spec(model):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col_spec(model, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p)

+    world_size = torch.distributed.get_world_size()
+    # world size, dp = 2, tp =2, construct a hybrid parallelism.
+    if world_size == 4:
+        pg = ProcessGroup(tp_degree=2)
+    else:
+        pg = ProcessGroup(tp_degree=world_size)
     if tp_init_spec_func:
-        tp_init_spec_func(model)
+        tp_init_spec_func(model, pg)

     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager)
+    model = ZeroDDP(model, gemini_manager, pg)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
-    torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
-    print(chunk_manager)
-    check_param_equal(model, torch_model)
+    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
+    # print(chunk_manager)
+    check_param_equal(model, torch_model, pg)
     model.train()
     torch_model.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    set_seed(pg.dp_local_rank())
     for i, (input_ids, attn_mask) in enumerate(train_dataloader):
         if i > 2:
             break
         logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
         torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
         assert tensor_equal(logits, torch_logits)
-        check_grad_equal(model, torch_model)
+        check_grad_equal(model, torch_model, pg)
         optim.step()
         torch_optim.step()
-        check_param_equal(model, torch_model)
+        check_param_equal(model, torch_model, pg)

 def run_dist(rank, world_size, port):
     config = {}
-    if world_size == 4:
-        config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if world_size == 4:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@pytest.mark.skip("under development")
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
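run_gpt above now derives its hybrid layout from a ProcessGroup instead of a gpc parallel config: with four ranks it asks for tp_degree=2 (leaving a data-parallel degree of 2), otherwise pure TP, and the torch reference model is wrapped in DDP over the data-parallel slice. A condensed sketch of that wiring, assuming the same API and a launched context (the Linear module is only a stand-in for the GPT reference model):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.tensor import ProcessGroup

world_size = torch.distributed.get_world_size()
# 4 ranks -> 2-way TP x 2-way DP hybrid; otherwise the whole world is one TP group
pg = ProcessGroup(tp_degree=2) if world_size == 4 else ProcessGroup(tp_degree=world_size)

torch_model = torch.nn.Linear(8, 8).cuda()    # stand-in reference model
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())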
View File
@@ -1,12 +1,10 @@
 import pytest
 import colossalai
 import torch
-from colossalai.context.parallel_mode import ParallelMode
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from functools import partial
 from tests.test_tensor._utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from colossalai.tensor import ProcessGroup

 def init_zero(model_builder, placement_policy):
@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):
     model.train()
     model_copy.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    pg = ProcessGroup()
+    set_seed(pg.dp_local_rank())
     data_iter = iter(train_dataloader)
     data, label = map(lambda x: x.cuda(), next(data_iter))
View File
@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
 from colossalai.testing import parameterize
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
+from colossalai.tensor import ProcessGroup

 def init_zero(model, use_chunk, use_zero, placement_policy):
@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    return ZeroDDP(model, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(model, gemini_manager, pg)

 def run_step(model, optim, criterion, data, label):
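init_zero now hands a ProcessGroup to ZeroDDP explicitly. A sketch of the updated wiring that mirrors the hunk above; the ZeroDDP import path is an assumption (colossalai.nn.parallel), everything else follows the calls shown in this commit:

from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.nn.parallel import ZeroDDP    # assumed import path
from colossalai.tensor import ProcessGroup

def init_zero_sketch(model, use_chunk, use_zero, placement_policy):
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    pg = ProcessGroup()    # defaults to the whole world, as in the hunk above
    return ZeroDDP(model, gemini_manager, pg)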