mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-27 07:47:05 +00:00
parent dcdd8a5ef7
commit 0a25e16e46
@@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
+from ..layer._operation import _gather
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
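The added import pulls in `_gather`, the collective used in the hunks below to reassemble vocab-sharded logits across the tensor-parallel group. As a rough orientation only (a hedged sketch, not the library's actual implementation — the real `colossalai.shardformer.layer._operation._gather` is autograd-aware and its internals may differ), a last-dim gather can be written with plain `torch.distributed`:

# Hedged sketch: `gather_last_dim` is a hypothetical stand-in for the
# imported `_gather`, shown only to illustrate the collective involved.
import torch
import torch.distributed as dist

def gather_last_dim(tensor, dim, process_group):
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return tensor  # single rank: nothing to gather
    # Collect every rank's shard, then stitch them back along `dim`.
    shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor.contiguous(), group=process_group)
    return torch.cat(shards, dim=dim)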
@@ -288,6 +289,9 @@ class LlamaPipelineForwards:
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
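Why the gather: with vocab-parallel tensor parallelism, each rank's lm_head emits only a `vocab_size / tp_size` slice of the logits. The distributed cross-entropy can consume that slice directly, but callers that expect full-vocab logits (generation, sampling) need them reassembled, which is what `parallel_output=False` requests. A single-process toy, with illustrative shapes that are assumptions rather than values from the diff:

import torch

# Simulate tp_size=2 ranks each holding half of the vocab dimension.
tp_size, batch, seq, vocab = 2, 1, 4, 8
full_logits = torch.randn(batch, seq, vocab)
shards = full_logits.chunk(tp_size, dim=-1)   # per-rank [1, 4, 4] slices
regathered = torch.cat(shards, dim=-1)        # what the gather reconstructs
assert torch.equal(regathered, full_logits)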
@@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits, shift_labels)
 
+        if not shard_config.parallel_output:
+            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
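The same three-line guard is mirrored here in `get_lm_forward_with_dist_cross_entropy`, the substitute forward installed when distributed cross-entropy is enabled, so both the pipeline path and the dist-cross-entropy path honor the flag. The final hunk adds the flag itself to the `ShardConfig` dataclass: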
@@ -34,6 +34,7 @@ class ShardConfig:
     enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
+    parallel_output = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
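One detail worth flagging: `parallel_output = True` carries no type annotation, so `dataclasses` treats it as a plain class attribute rather than a generated `__init__` field. A minimal sketch of the consequence, using a hypothetical `ShardConfigSketch` rather than the real class:

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ShardConfigSketch:
    enable_all_optimization: bool = False
    parallel_output = True  # no annotation: class attribute, not an init field
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

cfg = ShardConfigSketch()
cfg.parallel_output = False  # must be toggled after construction;
# ShardConfigSketch(parallel_output=False) would raise a TypeError.

Annotating it as `parallel_output: bool = True` would make the flag configurable through the constructor like the neighboring `enable_*` fields.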