diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index c4cf3fb85..134558b6a 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -110,4 +110,6 @@ class DistCrossEntropy(Function):
 
 def cross_entropy_1d(
     vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
 ) -> torch.Tensor:
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
+
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index b9342df04..60bc8b711 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -24,10 +24,7 @@ from transformers.models.llama.modeling_llama import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import ColoAttention, cross_entropy_1d
@@ -564,7 +561,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            #attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            # attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -824,4 +821,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )
 
-    return forward
\ No newline at end of file
+    return forward
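
Note on the loss.py change: before this patch, `cross_entropy_1d` had no return statement, so callers such as `get_lm_forward_with_dist_cross_entropy` received `None` instead of the loss. Below is a minimal single-process sketch of the restored return path; the 1-rank gloo group, the address/port, and the tensor shapes are illustrative assumptions and not part of the patch.

# Minimal sketch of the restored return path in cross_entropy_1d.
# Assumptions (not part of the patch): a 1-rank gloo group stands in for a real
# tensor-parallel group, and logits are pre-flattened to [tokens, local_vocab].
import torch
import torch.distributed as dist

from colossalai.shardformer.layer import cross_entropy_1d

if __name__ == "__main__":
    # With world_size=1 the local vocab shard is the full vocabulary.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    num_tokens, vocab_size = 16, 32000
    vocab_logits = torch.randn(num_tokens, vocab_size)
    labels = torch.randint(0, vocab_size, (num_tokens,))

    # Before this patch the function body ended at the signature, so this call returned None.
    loss = cross_entropy_1d(vocab_logits, labels, ignore_index=-100, process_group=None)
    print(loss)

    dist.destroy_process_group()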