Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] fix gathering output when using tensor parallelism (#5431)
* padding vocab_size when using pipeline parallelism
* fix gather output
* fix resize embedding
* revert
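Alongside the gather fix, the message mentions padding vocab_size and resizing the embedding for parallel training: the vocabulary is rounded up to a multiple of the parallel world size so the embedding and lm_head weights can be split into equal shards. A minimal sketch of that rounding, assuming the usual divisibility rule (the helper name pad_vocab_size is illustrative, not ColossalAI's API):

# Illustrative helper: round the vocabulary up so embedding / lm_head weights
# can be sharded evenly across tensor-parallel ranks.
def pad_vocab_size(vocab_size: int, divisor: int) -> int:
    remainder = vocab_size % divisor
    return vocab_size if remainder == 0 else vocab_size + (divisor - remainder)


# Example: LLaMA's 32000-token vocab already divides by tp_size=8, but a vocab
# of 32003 (e.g. after adding special tokens) would be padded to 32008.
assert pad_vocab_size(32000, 8) == 32000
assert pad_vocab_size(32003, 8) == 32008

With a Hugging Face model, the padded size could then be applied with model.resize_token_embeddings(padded_size) before the weights are sharded.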
@@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ class LlamaPipelineForwards:
                 loss = loss_fct(shift_logits, shift_labels)
 
             if not shard_config.parallel_output:
-                logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+                logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
 
             if not return_dict:
                 output = (logits,) + outputs[1:]
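The call-site change above (and the identical one in the last hunk) swaps the raw _gather collective for gather_forward_split_backward: when parallel_output is disabled, the vocab-sharded logits are concatenated from all tensor-parallel ranks in the forward pass, and each rank receives only its own slice of the gradient in the backward pass. A minimal sketch of that pattern, assuming an initialized torch.distributed process group (illustrative only, not the implementation in ..layer._operation):

import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather shards along `dim` in forward; split the gradient in backward."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        world_size = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        # Each rank keeps only the gradient slice for its own shard.
        grad = grad_output.chunk(world_size, dim=ctx.dim)[rank]
        return grad.contiguous(), None, None


def gather_logits(logits, dim, group):
    # e.g. gather_logits(logits, -1, tp_group) to reassemble the full vocab dim
    return _GatherForwardSplitBackward.apply(logits, dim, group)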
@@ -485,8 +485,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal
 
-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
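This hunk stops rebuilding the ColoAttention wrapper on every call of the patched forward and instead caches it on the module, constructing it lazily the first time via the hasattr guard. The same pattern in isolation (a toy ExpensiveHelper/ToyAttention, purely illustrative and unrelated to ColossalAI's classes):

import torch
import torch.nn as nn


class ExpensiveHelper:
    """Stand-in for ColoAttention: pretend construction is costly."""
    built = 0  # counts how many times the helper was constructed

    def __init__(self, embed_dim: int, num_heads: int):
        ExpensiveHelper.built += 1
        self.embed_dim, self.num_heads = embed_dim, num_heads

    def __call__(self, x):
        return x  # no-op; a real helper would run attention here


class ToyAttention(nn.Module):
    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.hidden_size, self.num_heads = hidden_size, num_heads

    def forward(self, x):
        # Lazily build the helper once and cache it on the module,
        # instead of re-creating it on every forward call.
        if not hasattr(self, "attention"):
            self.attention = ExpensiveHelper(embed_dim=self.hidden_size, num_heads=self.num_heads)
        return self.attention(x)


m = ToyAttention()
for _ in range(3):
    m(torch.randn(2, 8, 64))
assert ExpensiveHelper.built == 1  # constructed once, not three times

Constructing the helper once per module rather than once per forward call keeps repeated allocation off the hot path; the trade-off is that the cached object will not pick up later configuration changes.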
@@ -593,7 +594,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             loss = loss_fct(shift_logits, shift_labels)
 
         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
 
         if not return_dict:
             output = (logits,) + outputs[1:]