[fix] fix linear (no tp) ops func name;

Author: duanjunwen, 2024-10-31 08:18:28 +00:00
Parent: d2e05a99b3
Commit: 5f0924361d
6 changed files with 19 additions and 41 deletions

View File

@@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
 from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
-from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
+from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
@@ -11,7 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
 __all__ = [
     "Embedding1D",
     "VocabParallelEmbedding1D",
-    "Linear1D",
+    "LinearWithGradAccum",
     "Linear1D_Col",
     "Linear1D_Row",
     "GPT2FusedLinearConv1D_Col",

View File

@@ -235,17 +235,16 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None, None


-class LinearBase(torch.autograd.Function):
+class LinearWithGradAccum(torch.autograd.Function):
     """
     Linear layer baseline (no tensor parallel version).
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+    def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.fp8_communication = fp8_communication
         ctx.use_zbv = use_zbv
         if bias is not None:
             output = F.linear(input_, weight, bias)
@@ -258,7 +257,6 @@ class LinearBase(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
-        ctx.fp8_communication
         use_zbv = ctx.use_zbv

         def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
@@ -1201,8 +1199,8 @@ def linear_with_async_comm(
     )


-def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
-    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
+def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
+    return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


 def linear_gather_forward_reducescatter_backward(
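
A minimal sketch of calling the renamed functional wrapper directly; the signature comes from the hunk above, while the absolute module path and the tensor shapes are illustrative assumptions:

import torch
from colossalai.shardformer.layer._operation import linear_with_grad_accum

x = torch.randn(4, 32, requires_grad=True)    # assumed shapes
w = torch.randn(128, 32, requires_grad=True)
b = torch.randn(128, requires_grad=True)

# positional arguments follow the new signature:
# (input_, weight, bias, async_grad_allreduce, use_zbv=False)
out = linear_with_grad_accum(x, w, b, False, use_zbv=False)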

View File

@@ -25,10 +25,10 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_reducescatter_backward,
     gather_forward_split_backward,
-    linear_base,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
+    linear_with_grad_accum,
     reduce_forward,
     reducescatter_forward_gather_backward,
     split_forward_gather_backward,
@@ -36,10 +36,10 @@ from ._operation import (
 from .parallel_module import PaddingParallelModule, ParallelModule
 from .utils import create_randomizer_with_offset

-__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]
+__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]


-class Linear1D(ParallelModule):
+class LinearWithGradAccum(ParallelModule):
     r"""Linear layer with no parallelism.

     Args:
@@ -69,16 +69,11 @@ class Linear1D(ParallelModule):
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
-        gather_output: bool = False,
-        seq_parallel_mode: str = None,
-        seq_parallel_dim: int = 1,
-        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        fp8_communication: bool = False,
         use_zbv: bool = False,
         **kwargs,
     ):
@@ -87,13 +82,8 @@ class Linear1D(ParallelModule):
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
-        self.gather_output = gather_output
-        self.seq_parallel_mode = seq_parallel_mode
-        self.seq_parallel_dim = seq_parallel_dim
-        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.fp8_communication = fp8_communication
         self.use_zbv = use_zbv

         if skip_bias_add and not bias:
@@ -143,7 +133,7 @@ class Linear1D(ParallelModule):
         bias = module.bias is not None
         device = module.weight.device

-        linear_1d = Linear1D(
+        linear_1d = LinearWithGradAccum(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
@@ -174,12 +164,11 @@ class Linear1D(ParallelModule):
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None

-        output_parallel = linear_base(
+        output_parallel = linear_with_grad_accum(
             input_parallel,
             self.weight,
             bias,
             False,
-            fp8_communication=self.fp8_communication,
             use_zbv=self.use_zbv,
         )
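
A short usage sketch of the renamed module; the from_native_module call mirrors the test changes at the end of this commit, while the layer sizes and the CUDA device are illustrative:

import torch
import torch.nn as nn
from colossalai.shardformer.layer import LinearWithGradAccum

# wrap a plain nn.Linear; no tensor parallelism is applied, the wrapper only
# routes the matmul through linear_with_grad_accum (optionally with ZBV grad accumulation)
linear = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
    linear, parallel_input=False, seq_parallel_mode=None, use_zbv=False
)
out = linear_base(torch.randn(4, 32).cuda())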

View File

@@ -52,10 +52,7 @@ class MixtralPolicy(Policy):
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         tp_size = self.shard_config.tensor_parallel_size
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         # modified for both SP and TP
         num_q_heads = self.model.config.num_attention_heads
@@ -334,10 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
 class MixtralForCausalLMPolicy(MixtralPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -400,10 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
         from transformers import MixtralForSequenceClassification

         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
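
All three policies now resolve the ZBV flag with a single boolean expression instead of an if/else block. A hypothetical standalone helper (not part of the codebase) that captures the same logic:

def resolve_use_zbv(stage_manager) -> bool:
    # False when no pipeline stage manager is configured,
    # otherwise forward the manager's use_zbv setting
    return stage_manager is not None and stage_manager.use_zbv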

View File

@@ -366,10 +366,10 @@ def main():
         )
         loss = outputs["loss"]
         if args.pp_style == "zbv":
-            if dist.get_rank() == 0:
+            if coordinator.is_master():
                 print(f"Step {step} loss: {loss}")
         else:
-            if dist.get_rank() == dist.get_world_size() - 1:
+            if coordinator.is_last_process():
                 print(f"Step {step} loss: {loss}")
         optimizer.step()
         optimizer.zero_grad()
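
The loss logging now goes through the coordinator instead of raw torch.distributed rank checks. A minimal sketch of the assumed setup (in the example, the coordinator object is created elsewhere, after colossalai.launch):

from colossalai.cluster import DistCoordinator

coordinator = DistCoordinator()   # requires an initialized process group
if coordinator.is_master():       # first rank: where the zbv schedule reports the loss
    print("log from the first pipeline stage")
if coordinator.is_last_process():  # last rank: where the other schedules report it
    print("log from the last pipeline stage")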

View File

@@ -9,7 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.weight_grad_store import WeightGradStore
-from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
 from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-        linear_base = Linear1D.from_native_module(
+        linear_base = LinearWithGradAccum.from_native_module(
             linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
         )
     assert linear_base.weight.shape == torch.Size([128, 32])
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-        linear_base = Linear1D.from_native_module(
+        linear_base = LinearWithGradAccum.from_native_module(
             linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
         )
     assert linear_base.weight.shape == torch.Size([128, 32])