Mirror of https://github.com/hpcaitech/ColossalAI.git

[fix] fix linear (no tp) ops func name;

commit 5f0924361d (parent d2e05a99b3)
@@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
 from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
-from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
+from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
@@ -11,7 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
 __all__ = [
     "Embedding1D",
     "VocabParallelEmbedding1D",
-    "Linear1D",
+    "LinearWithGradAccum",
     "Linear1D_Col",
     "Linear1D_Row",
     "GPT2FusedLinearConv1D_Col",
@@ -235,17 +235,16 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None, None


-class LinearBase(torch.autograd.Function):
+class LinearWithGradAccum(torch.autograd.Function):
     """
     Linear layer baseline (no tensor parallel version).
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+    def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.fp8_communication = fp8_communication
         ctx.use_zbv = use_zbv
         if bias is not None:
             output = F.linear(input_, weight, bias)
@@ -258,7 +257,6 @@ class LinearBase(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
-        ctx.fp8_communication
         use_zbv = ctx.use_zbv

         def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
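For orientation: the execute_w_pass_grad_accum helper above produces the weight gradient that ZBV pipeline schedules can defer through WeightGradStore instead of computing eagerly. A minimal standalone sketch of that gradient math (illustrative shapes only, not code from this diff):

    # Sketch: for y = x @ W.T + b, the weight gradient is grad_W = grad_y.T @ x.
    # ZBV-style schedules stash (input, grad_output) and run this GEMM later;
    # the deferred w-pass in the backward above amounts to the same computation.
    import torch

    x = torch.randn(4, 32)          # saved input_
    grad_y = torch.randn(4, 128)    # incoming grad_output
    grad_w = grad_y.t() @ x         # shape [128, 32], matches the weight
    assert grad_w.shape == (128, 32)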
@@ -1201,8 +1199,8 @@ def linear_with_async_comm(
     )


-def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
-    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
+def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
+    return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


def linear_gather_forward_reducescatter_backward(
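A hedged usage sketch of the renamed functional op, matching the signature in the hunk above (tensor shapes are illustrative and not from this commit; with async_grad_allreduce=False this is a plain single-device autograd path):

    import torch
    from colossalai.shardformer.layer._operation import linear_with_grad_accum

    x = torch.randn(4, 32, requires_grad=True)
    w = torch.randn(128, 32, requires_grad=True)
    b = torch.randn(128, requires_grad=True)
    # positional async_grad_allreduce=False: no communication is involved
    out = linear_with_grad_accum(x, w, b, False, use_zbv=False)
    out.sum().backward()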
@@ -25,10 +25,10 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_reducescatter_backward,
     gather_forward_split_backward,
-    linear_base,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
+    linear_with_grad_accum,
     reduce_forward,
     reducescatter_forward_gather_backward,
     split_forward_gather_backward,
@@ -36,10 +36,10 @@ from ._operation import (
 from .parallel_module import PaddingParallelModule, ParallelModule
 from .utils import create_randomizer_with_offset

-__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]
+__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]


-class Linear1D(ParallelModule):
+class LinearWithGradAccum(ParallelModule):
     r"""Linear layer with no parallelism.

     Args:
@@ -69,16 +69,11 @@ class Linear1D(ParallelModule):
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
-        gather_output: bool = False,
-        seq_parallel_mode: str = None,
-        seq_parallel_dim: int = 1,
-        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        fp8_communication: bool = False,
         use_zbv: bool = False,
         **kwargs,
     ):
@@ -87,13 +82,8 @@ class Linear1D(ParallelModule):
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
-        self.gather_output = gather_output
-        self.seq_parallel_mode = seq_parallel_mode
-        self.seq_parallel_dim = seq_parallel_dim
-        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.fp8_communication = fp8_communication
         self.use_zbv = use_zbv

         if skip_bias_add and not bias:
@@ -143,7 +133,7 @@ class Linear1D(ParallelModule):
         bias = module.bias is not None
         device = module.weight.device

-        linear_1d = Linear1D(
+        linear_1d = LinearWithGradAccum(
             in_features=in_features,
             out_features=out_features,
             bias=bias,
@@ -174,12 +164,11 @@ class Linear1D(ParallelModule):

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_base(
+        output_parallel = linear_with_grad_accum(
             input_parallel,
             self.weight,
             bias,
             False,
-            fp8_communication=self.fp8_communication,
             use_zbv=self.use_zbv,
         )

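A sketch of constructing the renamed module directly, using only keywords that survive the signature hunk above (values illustrative; treat this as an assumption about the constructor, not documented API):

    import torch
    from colossalai.shardformer.layer import LinearWithGradAccum

    # no-parallelism layer; forward routes through linear_with_grad_accum
    layer = LinearWithGradAccum(in_features=32, out_features=128, bias=True, use_zbv=False)
    y = layer(torch.randn(4, 32))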
@@ -52,10 +52,7 @@ class MixtralPolicy(Policy):
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         tp_size = self.shard_config.tensor_parallel_size
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         # modified for both SP and TP
         num_q_heads = self.model.config.num_attention_heads
@@ -334,10 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
 class MixtralForCausalLMPolicy(MixtralPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -400,10 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
         from transformers import MixtralForSequenceClassification

         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
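The three Mixtral policy hunks above apply one identical simplification: a four-line if/else collapses into a single boolean expression. A toy check of the equivalence (names hypothetical, not ColossalAI code):

    class _Stage:                      # stand-in for a pipeline stage manager
        use_zbv = True

    for mgr in (None, _Stage()):
        old = mgr.use_zbv if mgr else False
        new = mgr is not None and mgr.use_zbv
        assert old == new              # both are False when mgr is None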
@@ -366,10 +366,10 @@ def main():
         )
         loss = outputs["loss"]
         if args.pp_style == "zbv":
-            if dist.get_rank() == 0:
+            if coordinator.is_master():
                 print(f"Step {step} loss: {loss}")
         else:
-            if dist.get_rank() == dist.get_world_size() - 1:
+            if coordinator.is_last_process():
                 print(f"Step {step} loss: {loss}")
         optimizer.step()
         optimizer.zero_grad()
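The example script now asks the coordinator instead of querying torch.distributed ranks directly. Assuming the coordinator here is colossalai's DistCoordinator, the correspondence is roughly:

    import torch.distributed as dist
    from colossalai.cluster import DistCoordinator

    coordinator = DistCoordinator()    # requires torch.distributed to be initialized
    # coordinator.is_master()        ~ dist.get_rank() == 0
    # coordinator.is_last_process()  ~ dist.get_rank() == dist.get_world_size() - 1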
@@ -9,7 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.weight_grad_store import WeightGradStore
-from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
 from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-    linear_base = Linear1D.from_native_module(
+    linear_base = LinearWithGradAccum.from_native_module(
         linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
     )
     assert linear_base.weight.shape == torch.Size([128, 32])
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-    linear_base = Linear1D.from_native_module(
+    linear_base = LinearWithGradAccum.from_native_module(
         linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
    )
     assert linear_base.weight.shape == torch.Size([128, 32])
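For reference, a condensed sketch of what the renamed test exercises (mirroring the two hunks above; the seq_parallel_mode value is illustrative, and a CUDA device is assumed as in the test):

    import torch
    import torch.nn as nn
    from colossalai.shardformer.layer import LinearWithGradAccum

    linear = nn.Linear(32, 128).cuda()
    wrapped = LinearWithGradAccum.from_native_module(
        linear, parallel_input=False, seq_parallel_mode=None, use_zbv=False
    )
    # wrapping a native nn.Linear keeps the weight layout unchanged
    assert wrapped.weight.shape == torch.Size([128, 32])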