[shardformer]gather llama logits (#5398)

* gather llama logits

* fix
Author: flybird11111 (committed by GitHub)
Date: 2024-02-27 22:44:07 +08:00
Parent: dcdd8a5ef7
Commit: 0a25e16e46
2 changed files with 8 additions and 0 deletions


@@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
+from ..layer._operation import _gather
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -288,6 +289,9 @@ class LlamaPipelineForwards:
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
@@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
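
Context for the change: with a column-parallel LM head, each tensor-parallel rank holds only a vocab shard of the logits, which is fine while the loss is computed with cross_entropy_1d, but callers that set shard_config.parallel_output=False expect full-vocabulary logits in the returned output (e.g. for generation or evaluation). The added guard gathers the shards along the last dimension before packing the return value, and skips the extra communication when parallel_output stays True. A minimal, illustrative sketch of what such a gather does, written with plain torch.distributed (the real _gather in colossalai.shardformer.layer._operation is autograd-aware; this simplified stand-in ignores the backward pass):

    import torch
    import torch.distributed as dist


    def gather_last_dim(logits: torch.Tensor, process_group=None) -> torch.Tensor:
        """All-gather vocab-parallel logit shards along the last dimension.

        Each rank holds logits of shape (batch, seq_len, vocab_size // tp_size);
        concatenating the gathered shards on dim -1 restores the full vocabulary.
        Simplified illustration only, not the library's _gather op.
        """
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return logits

        shards = [torch.empty_like(logits) for _ in range(world_size)]
        dist.all_gather(shards, logits.contiguous(), group=process_group)
        return torch.cat(shards, dim=-1)

Without this gather, a caller requesting non-parallel output would receive logits whose last dimension is only vocab_size // tp_size, which does not match self.config.vocab_size and breaks downstream reshaping or sampling.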