diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 7aadc227e..fe102eecf 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -94,7 +93,6 @@ class LlamaPipelineForwards:
         batch_size, seq_length = input_shape
         device = hidden_states.device
 
-        # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
 
@@ -113,7 +111,6 @@ class LlamaPipelineForwards:
                 raise ValueError("cache_position is a required argument when using StaticCache.")
             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
-        seq_length_with_past = seq_length + past_seen_tokens
 
 
         if output_attentions:
@@ -143,7 +140,9 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values
+            )
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
@@ -207,7 +206,7 @@ class LlamaPipelineForwards:
                 output_attentions,
                 use_cache,
                 cache_position,
-                position_embeddings
+                position_embeddings,
             )
         else:
             layer_outputs = decoder_layer(
@@ -218,7 +217,7 @@ class LlamaPipelineForwards:
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                position_embeddings=position_embeddings
+                position_embeddings=position_embeddings,
             )
 
         hidden_states = layer_outputs[0]
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 1431a9c70..9ad63dd7f 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -33,11 +33,7 @@ class LlamaPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import (
-            LlamaAttention,
-            LlamaDecoderLayer,
-            LlamaModel,
-        )
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
 
         policy = {}
         embedding_cls = None