[shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix
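
Editor's note: the bullets above mention padding vocab_size and resizing embeddings when parallelism is enabled. As a rough sketch of that general technique only (padded_vocab_size and make_divisible_by are hypothetical names, not the helper this commit touches), the vocabulary is padded up to a multiple of the tensor-parallel degree so every rank holds an equally sized embedding shard:

import math

def padded_vocab_size(vocab_size: int, tp_size: int, make_divisible_by: int = 1) -> int:
    # Pad the vocabulary so it splits evenly across tensor-parallel ranks.
    multiple = make_divisible_by * tp_size
    return int(math.ceil(vocab_size / multiple) * multiple)

# Example: a 32000-token vocab with tp_size=3 is padded to 32001,
# so each rank owns a 10667-row slice of the embedding matrix.
# The padded size would then be applied with the usual HuggingFace call:
# model.resize_token_embeddings(padded_vocab_size(model.config.vocab_size, tp_size))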
Authored by: flybird11111
Date: 2024-03-25 17:21:51 +08:00
Committed by: GitHub
Parent: 34e909256c
Commit: 0688d92e2d
5 changed files with 20 additions and 33 deletions

@@ -16,7 +16,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
-from ..layer._operation import gather_forward_split_backward

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
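
Editor's note: the line removed above imports the gather helper whose call sites are deleted later in this commit. For readers unfamiliar with it, a minimal conceptual sketch of such an op follows; this is an illustration, not ColossalAI's actual _operation code:

import torch
import torch.distributed as dist

class _GatherForwardSplitBackward(torch.autograd.Function):
    """Sketch: all-gather the sharded last dim on the way forward,
    keep only this rank's gradient slice on the way back."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        world_size = dist.get_world_size(group)
        parts = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(parts, x.contiguous(), group=group)
        return torch.cat(parts, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        # Each rank keeps only the gradient slice that matches its input shard.
        return grad_output.chunk(world_size, dim=ctx.dim)[rank].contiguous(), None, None

# Usage sketch: full_logits = _GatherForwardSplitBackward.apply(sharded_logits, -1, tp_group)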
@@ -279,7 +278,7 @@ class LlamaPipelineForwards:
                 shift_labels = shift_labels.view(-1)
                 # Enable model parallelism
                 shift_labels = shift_labels.to(shift_logits.device)
-                if shard_config.enable_tensor_parallelism:
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
                     new_vocab_size = logits.shape[-1]
                     shift_logits = shift_logits.view(-1, new_vocab_size)
                     loss = cross_entropy_1d(
@@ -289,9 +288,6 @@ class LlamaPipelineForwards:
                     shift_logits = shift_logits.view(-1, self.config.vocab_size)
                     loss = loss_fct(shift_logits, shift_labels)

-            if not shard_config.parallel_output:
-                logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
-
             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
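
Editor's note: after this hunk, the pipeline forward only takes the distributed-loss path when the logits are actually still vocab-sharded, i.e. when tensor parallelism is on and parallel_output is set; otherwise it falls back to the plain CrossEntropyLoss, and no explicit gather of the logits is performed here any more. A small restatement of that branch follows, assuming a ShardConfig-like object and the cross_entropy_1d import shown earlier; causal_lm_loss is a hypothetical wrapper, not a function in this file:

from torch.nn import CrossEntropyLoss

def causal_lm_loss(shift_logits, shift_labels, full_vocab_size, shard_config):
    shift_labels = shift_labels.view(-1)
    if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
        # Logits hold only this rank's vocab shard; cross_entropy_1d reduces the
        # loss across shard_config.tensor_parallel_process_group internally.
        shift_logits = shift_logits.view(-1, shift_logits.shape[-1])
        return cross_entropy_1d(
            shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
        )
    # Logits already cover the full vocabulary (e.g. the head gathered its output).
    return CrossEntropyLoss()(shift_logits.view(-1, full_vocab_size), shift_labels)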
@@ -578,23 +574,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
-
-            if not shard_config.parallel_output:
-                logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+            )

         if not return_dict:
             output = (logits,) + outputs[1:]
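
Editor's note: with the simplification above, get_lm_forward_with_dist_cross_entropy always hands vocab-sharded logits to cross_entropy_1d, so this helper no longer branches on parallel_output or gathers the logits. For intuition, here is a conceptual sketch of how a 1D-parallel cross entropy can compute the loss without materializing the full vocabulary on any rank; it is an illustration only (ColossalAI's real kernel differs, and ignore_index handling is omitted):

import torch
import torch.distributed as dist

def sharded_cross_entropy(local_logits, labels, process_group):
    # local_logits: (N, V / tp_size) slice of the vocab dimension on this rank.
    rank = dist.get_rank(process_group)
    shard = local_logits.size(-1)                      # assumes equal shards per rank
    v_start, v_end = rank * shard, (rank + 1) * shard

    # Numerically stable log-sum-exp over the *sharded* vocab dimension.
    max_logit = local_logits.max(dim=-1).values
    dist.all_reduce(max_logit, op=dist.ReduceOp.MAX, group=process_group)
    shifted = local_logits - max_logit.unsqueeze(-1)
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # Only the rank that owns a given target id contributes its (shifted) logit.
    in_shard = (labels >= v_start) & (labels < v_end)
    target = torch.zeros_like(max_logit)
    target[in_shard] = shifted[in_shard, labels[in_shard] - v_start]
    dist.all_reduce(target, op=dist.ReduceOp.SUM, group=process_group)

    # loss_i = logsumexp(z_i) - z_i[label_i], averaged over tokens.
    return (sum_exp.log() - target).mean()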