mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer import dist_cross_entropy
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -330,30 +330,14 @@ class OPTPipelineForwards:
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.decoder.dtype,
|
||||
)
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
|
||||
loss = dist_cross_entropy(
|
||||
labels,
|
||||
logits,
|
||||
shard_config,
|
||||
self.lm_head.out_features,
|
||||
self.config.vocab_size,
|
||||
self.model.decoder.dtype,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
new_vocab_size = logits.shape[-1]
|
||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
dtype=self.model.decoder.dtype,
|
||||
)
|
||||
loss = dist_cross_entropy(
|
||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
Reference in New Issue
Block a user