[shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism (see the sketch after this list)

* fix gather output

* fix resize embedding

* revert
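
For the vocab-padding bullet above, a minimal sketch of the idea, assuming a hypothetical helper name (pad_vocab_size is illustrative and not part of this commit): when the embedding table is sharded across tensor-parallel ranks, vocab_size must divide evenly by the TP world size, so it is rounded up before the embedding is resized.

def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # Round vocab_size up to the nearest multiple of tp_size so that each
    # tensor-parallel rank holds an equally sized shard of the embedding.
    return ((vocab_size + tp_size - 1) // tp_size) * tp_size

# e.g. pad_vocab_size(32001, 8) -> 32008; the extra rows act as padding tokens.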
flybird11111
2024-03-18 15:55:11 +08:00
committed by GitHub
parent f2e8b9ef9f
commit 5e16bf7980
6 changed files with 32 additions and 13 deletions


@@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ class LlamaPipelineForwards:
             loss = loss_fct(shift_logits, shift_labels)
         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
         if not return_dict:
             output = (logits,) + outputs[1:]
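
Why the swap from _gather to gather_forward_split_backward matters: a raw gather is not autograd-aware, while gather_forward_split_backward all-gathers along a dimension in the forward pass and hands each rank back only its own slice of the gradient in the backward pass, so gradients stay correct when the gathered logits feed later computation. A rough sketch of that pattern, assuming torch.distributed is initialized (a simplification, not ColossalAI's actual implementation in ..layer._operation):

import torch
import torch.distributed as dist

class _GatherForwardSplitBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return x
        # All-gather the shards from every rank and concatenate along dim.
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if world_size == 1:
            return grad_output, None, None
        rank = dist.get_rank(ctx.process_group)
        # Each rank keeps only the gradient slice matching its local shard.
        return grad_output.chunk(world_size, dim=ctx.dim)[rank], None, None

def gather_forward_split_backward(x, dim, process_group):
    return _GatherForwardSplitBackward.apply(x, dim, process_group)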
@@ -485,8 +485,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal
-        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = attention(
+        if not hasattr(self, "attention"):
+            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
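
The hasattr guard introduced here is a lazy-initialization pattern: ColoAttention is built once and cached on self instead of being reconstructed on every forward call. A generic sketch of the same caching idea (get_cached_submodule is an illustrative name, not part of this commit):

def get_cached_submodule(module, name, factory):
    # Build the submodule on first use and cache it as an attribute,
    # so repeated forward calls reuse the same instance.
    if not hasattr(module, name):
        setattr(module, name, factory())
    return getattr(module, name)

# Usage mirroring the diff above:
# attn = get_cached_submodule(
#     self, "attention",
#     lambda: ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads),
# )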
@@ -593,7 +594,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             loss = loss_fct(shift_logits, shift_labels)
         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
         if not return_dict:
             output = (logits,) + outputs[1:]