[fix] rename the linear (no TP) op and module: LinearBase/Linear1D -> LinearWithGradAccum, linear_base -> linear_with_grad_accum

duanjunwen 2024-10-31 08:18:28 +00:00
parent d2e05a99b3
commit 5f0924361d
6 changed files with 19 additions and 41 deletions
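
At the call-site level the rename amounts to the following sketch (these import paths are the ones exercised by the updated __init__.py and test file further down in this diff):

    # before
    from colossalai.shardformer.layer import Linear1D
    from colossalai.shardformer.layer._operation import linear_base
    # after
    from colossalai.shardformer.layer import LinearWithGradAccum
    from colossalai.shardformer.layer._operation import linear_with_grad_accum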


@@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -11,7 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"Linear1D",
"LinearWithGradAccum",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",


@@ -235,17 +235,16 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearBase(torch.autograd.Function):
class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""
@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
@@ -258,7 +257,6 @@ class LinearBase(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
@@ -1201,8 +1199,8 @@ def linear_with_async_comm(
)
def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
def linear_gather_forward_reducescatter_backward(

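A minimal usage sketch of the renamed functional op above (the shapes, device, and standalone setting are assumptions; the signature and argument order follow linear_with_grad_accum as defined in this file, and the 32->128 dimensions mirror the test layer at the bottom of this diff):

    import torch
    from colossalai.shardformer.layer._operation import linear_with_grad_accum

    x = torch.randn(4, 32, device="cuda", requires_grad=True)   # assumed input shape
    w = torch.randn(128, 32, device="cuda", requires_grad=True)
    b = torch.randn(128, device="cuda", requires_grad=True)

    # async_grad_allreduce=False: there is no tensor-parallel group here, so nothing to all-reduce
    out = linear_with_grad_accum(x, w, b, False, use_zbv=False)
    out.sum().backward()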

@@ -25,10 +25,10 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_reducescatter_backward,
gather_forward_split_backward,
linear_base,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
linear_with_grad_accum,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
@@ -36,10 +36,10 @@ from ._operation import (
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
class Linear1D(ParallelModule):
class LinearWithGradAccum(ParallelModule):
r"""Linear layer with no parallelism.
Args:
@@ -69,16 +69,11 @@ class Linear1D(ParallelModule):
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
**kwargs,
):
@@ -87,13 +82,8 @@ class Linear1D(ParallelModule):
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias:
@@ -143,7 +133,7 @@ class Linear1D(ParallelModule):
bias = module.bias is not None
device = module.weight.device
linear_1d = Linear1D(
linear_1d = LinearWithGradAccum(
in_features=in_features,
out_features=out_features,
bias=bias,
@@ -174,12 +164,11 @@ class Linear1D(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_base(
output_parallel = linear_with_grad_accum(
input_parallel,
self.weight,
bias,
False,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
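
At the module level, a usage sketch mirroring the updated test at the bottom of this diff (the 32->128 dimensions and the parallel_input/use_zbv keywords come from that test; the input tensor is an assumption, and distributed init as done by the test harness is assumed):

    import torch
    import torch.nn as nn
    from colossalai.shardformer.layer import LinearWithGradAccum

    linear = nn.Linear(32, 128).cuda()
    # wrap the native module; no tensor parallelism, only the (optional) zbv-style weight-grad accumulation
    linear_base = LinearWithGradAccum.from_native_module(linear, parallel_input=False, use_zbv=False)
    out = linear_base(torch.randn(4, 32, device="cuda"))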


@@ -52,10 +52,7 @@ class MixtralPolicy(Policy):
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
@@ -334,10 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@@ -400,10 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
from transformers import MixtralForSequenceClassification
policy = super().module_policy()
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification


@@ -366,10 +366,10 @@ def main():
)
loss = outputs["loss"]
if args.pp_style == "zbv":
if dist.get_rank() == 0:
if coordinator.is_master():
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1:
if coordinator.is_last_process():
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()


@@ -9,7 +9,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = Linear1D.from_native_module(
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
)
assert linear_base.weight.shape == torch.Size([128, 32])
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = Linear1D.from_native_module(
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
)
assert linear_base.weight.shape == torch.Size([128, 32])