Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-06 06:02:16 +00:00)
[fix] fix linear (no tp) ops func name;
parent d2e05a99b3
commit 5f0924361d
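In short, this commit renames the no-tensor-parallel linear primitives: the autograd function `LinearBase` becomes `LinearWithGradAccum`, its functional wrapper `linear_base` becomes `linear_with_grad_accum`, and the `Linear1D` module becomes `LinearWithGradAccum`; the unused `fp8_communication` argument is dropped from this path, the Mixtral policies derive `use_zbv` with a one-line expression, and the example script logs through the coordinator instead of raw `dist.get_rank()` checks. A minimal call-site sketch of the renamed wrapper follows; the shapes, the CUDA device, and the variable names are illustrative, while the function name, import location, and signature come from the hunks below:

import torch
import torch.nn as nn

from colossalai.shardformer.layer._operation import linear_with_grad_accum

# Illustrative shapes only; the positional False is async_grad_allreduce,
# per the signature shown in this diff.
x = torch.randn(4, 32, device="cuda")
layer = nn.Linear(32, 128, device="cuda")
out = linear_with_grad_accum(x, layer.weight, layer.bias, False, use_zbv=False)
assert out.shape == (4, 128)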
@@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
 from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
-from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
+from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
@@ -11,7 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
 __all__ = [
     "Embedding1D",
     "VocabParallelEmbedding1D",
-    "Linear1D",
+    "LinearWithGradAccum",
     "Linear1D_Col",
     "Linear1D_Row",
     "GPT2FusedLinearConv1D_Col",
@@ -235,17 +235,16 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None, None


-class LinearBase(torch.autograd.Function):
+class LinearWithGradAccum(torch.autograd.Function):
     """
     Linear layer baseline (no tensor parallel version).
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+    def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.fp8_communication = fp8_communication
         ctx.use_zbv = use_zbv
         if bias is not None:
             output = F.linear(input_, weight, bias)
@@ -258,7 +257,6 @@ class LinearBase(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
-        ctx.fp8_communication
         use_zbv = ctx.use_zbv

         def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
@@ -1201,8 +1199,8 @@ def linear_with_async_comm(
     )


-def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
-    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
+def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
+    return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


 def linear_gather_forward_reducescatter_backward(
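As a quick sanity sketch (not part of the commit), with `use_zbv=False` and `async_grad_allreduce=False` the renamed function should behave like a plain linear layer in both forward and backward; shapes, device, and the sum-based loss are illustrative assumptions:

import torch
import torch.nn.functional as F

from colossalai.shardformer.layer._operation import linear_with_grad_accum

x = torch.randn(4, 32, device="cuda", requires_grad=True)
w = torch.randn(128, 32, device="cuda", requires_grad=True)
b = torch.randn(128, device="cuda", requires_grad=True)

# Renamed path: LinearWithGradAccum.apply(...) under the hood.
out = linear_with_grad_accum(x, w, b, False, use_zbv=False)
out.sum().backward()

# Plain F.linear reference for comparison.
x_ref = x.detach().clone().requires_grad_()
w_ref = w.detach().clone().requires_grad_()
b_ref = b.detach().clone().requires_grad_()
F.linear(x_ref, w_ref, b_ref).sum().backward()

torch.testing.assert_close(x.grad, x_ref.grad)
torch.testing.assert_close(w.grad, w_ref.grad)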
@@ -25,10 +25,10 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_reducescatter_backward,
     gather_forward_split_backward,
-    linear_base,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
+    linear_with_grad_accum,
     reduce_forward,
     reducescatter_forward_gather_backward,
     split_forward_gather_backward,
@@ -36,10 +36,10 @@ from ._operation import (
 from .parallel_module import PaddingParallelModule, ParallelModule
 from .utils import create_randomizer_with_offset

-__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]
+__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]


-class Linear1D(ParallelModule):
+class LinearWithGradAccum(ParallelModule):
     r"""Linear layer with no parallelism.

     Args:
@@ -69,16 +69,11 @@ class Linear1D(ParallelModule):
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
-        gather_output: bool = False,
-        seq_parallel_mode: str = None,
-        seq_parallel_dim: int = 1,
-        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        fp8_communication: bool = False,
         use_zbv: bool = False,
         **kwargs,
     ):
@@ -87,13 +82,8 @@ class Linear1D(ParallelModule):
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
-        self.gather_output = gather_output
-        self.seq_parallel_mode = seq_parallel_mode
-        self.seq_parallel_dim = seq_parallel_dim
-        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
-        self.fp8_communication = fp8_communication
         self.use_zbv = use_zbv

         if skip_bias_add and not bias:
@@ -143,7 +133,7 @@ class Linear1D(ParallelModule):
         bias = module.bias is not None
         device = module.weight.device

-        linear_1d = Linear1D(
+        linear_1d = LinearWithGradAccum(
             in_features=in_features,
             out_features=out_features,
             bias=bias,
@@ -174,12 +164,11 @@ class Linear1D(ParallelModule):

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_base(
+        output_parallel = linear_with_grad_accum(
             input_parallel,
             self.weight,
             bias,
             False,
-            fp8_communication=self.fp8_communication,
             use_zbv=self.use_zbv,
         )

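A minimal sketch of module-level usage after the rename, modeled on the `from_native_module` calls in the tests further below; the layer size is illustrative, the parametrized `seq_parallel_mode` argument from the tests is omitted here, and the shape of the forward output assumes the default `skip_bias_add=False`:

import torch
import torch.nn as nn

from colossalai.shardformer.layer import LinearWithGradAccum

# Wrap an ordinary nn.Linear; no tensor parallelism is involved, so the weight
# keeps its full [out_features, in_features] shape.
dense = nn.Linear(32, 128).cuda()
layer = LinearWithGradAccum.from_native_module(dense, parallel_input=False, use_zbv=False)
assert layer.weight.shape == torch.Size([128, 32])

# With the default skip_bias_add=False this is expected to behave like the
# wrapped nn.Linear (assumption based on the forward shown in this diff).
out = layer(torch.randn(4, 32, device="cuda"))
assert out.shape == (4, 128)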
@@ -52,10 +52,7 @@ class MixtralPolicy(Policy):
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         tp_size = self.shard_config.tensor_parallel_size
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         # modified for both SP and TP
         num_q_heads = self.model.config.num_attention_heads
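The three Mixtral policy classes all collapse the same four-line `if/else` into a single boolean expression. A small standalone illustration of the equivalence; the `_Stage` class and variable names are placeholders, not ColossalAI API:

# Placeholder stand-in for a pipeline stage manager; only the attribute matters.
class _Stage:
    def __init__(self, use_zbv):
        self.use_zbv = use_zbv

for manager in (None, _Stage(False), _Stage(True)):
    # Old form from the left side of the hunk above.
    if manager:
        old = manager.use_zbv
    else:
        old = False
    # New one-liner from this commit.
    new = manager is not None and manager.use_zbv
    assert old == new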
@@ -334,10 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
 class MixtralForCausalLMPolicy(MixtralPolicy):
     def module_policy(self):
         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -400,10 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
         from transformers import MixtralForSequenceClassification

         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -366,10 +366,10 @@ def main():
         )
         loss = outputs["loss"]
         if args.pp_style == "zbv":
-            if dist.get_rank() == 0:
+            if coordinator.is_master():
                 print(f"Step {step} loss: {loss}")
         else:
-            if dist.get_rank() == dist.get_world_size() - 1:
+            if coordinator.is_last_process():
                 print(f"Step {step} loss: {loss}")
         optimizer.step()
         optimizer.zero_grad()
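The example script swaps raw `dist.get_rank()` checks for coordinator helpers; both method names come from the new lines of the hunk above. A hedged sketch of the pattern, assuming the process group is already initialized (for example via `colossalai.launch_from_torch`) and assuming `DistCoordinator` is still imported from `colossalai.cluster`:

from colossalai.cluster import DistCoordinator

# Assumes torch.distributed has been initialized before this point.
coordinator = DistCoordinator()

if coordinator.is_master():  # replaces dist.get_rank() == 0
    print("log from the first rank")
if coordinator.is_last_process():  # replaces dist.get_rank() == dist.get_world_size() - 1
    print("log from the last rank")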
@@ -9,7 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.weight_grad_store import WeightGradStore
-from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
 from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -124,7 +124,7 @@ def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: b
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-    linear_base = Linear1D.from_native_module(
+    linear_base = LinearWithGradAccum.from_native_module(
         linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
     )
     assert linear_base.weight.shape == torch.Size([128, 32])
@@ -164,7 +164,7 @@ def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-    linear_base = Linear1D.from_native_module(
+    linear_base = LinearWithGradAccum.from_native_module(
         linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
     )
     assert linear_base.weight.shape == torch.Size([128, 32])