diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index a4a8c81ae..1e0f7be24 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1056,8 +1056,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
-        # if sp_size is None or sp_size <= 1:
-        #     enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index ee8cfc80f..7aadc227e 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -515,13 +515,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
 
-        # if sp_mode == "all_to_all":
-        #     # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # bsz, q_len, _ = query_states.size()
-        #     # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
-
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -548,7 +541,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        # cos, sin = self.rotary_emb(value_states, position_ids)
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -607,14 +599,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
-        # return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
 
     return forward
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index ae718dd94..1431a9c70 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -36,19 +36,10 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            # LlamaFlashAttention2,
             LlamaModel,
-            # LlamaSdpaAttention,
         )
 
-        # ATTN_IMPLEMENTATION = {
-        #     "eager": LlamaAttention,
-        #     "flash_attention_2": LlamaFlashAttention2,
-        #     "sdpa": LlamaSdpaAttention,
-        # }
         policy = {}
-        attn_cls = LlamaAttention
-        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -82,7 +73,7 @@ class LlamaPolicy(Policy):
                 num_kv_heads //= sp_size
                 decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
-            policy[attn_cls] = ModulePolicyDescription(
+            policy[LlamaAttention] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@@ -91,7 +82,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=LlamaAttention,
             )
 
         if self.pipeline_stage_manager is None:
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index f3997a158..ee2f1f405 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -225,7 +225,6 @@ class ModelSharder(object):
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
-            print("held_layers", held_layers)
             set_tensors_to_none(self.model, exclude=set(held_layers))
             return set(self._get_recursive_held_layers(held_layers))
         return None
diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py
index 5ae7e9de7..2bac37bfe 100644
--- a/colossalai/shardformer/shard/utils.py
+++ b/colossalai/shardformer/shard/utils.py
@@ -16,6 +16,4 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
     for n, p in model.named_parameters(recurse=False):
         setattr(model, n, None)
     for n, buf in model.named_buffers(recurse=False):
-        import torch
-        print("buffer", n, torch.distributed.get_rank())
         setattr(model, n, None)