Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-31 11:25:27 +00:00)
fix

This commit is contained in:
parent d7a9eb0f67
commit 2f615a49fd
@@ -1056,8 +1056,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
-        # if sp_size is None or sp_size <= 1:
-        #     enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
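
With the assertion above, selecting the zero-bubble-V ("zbv") pipeline now requires the caller to pass scheduler_nodes explicitly. Below is a minimal, hedged sketch of such a call; it assumes an already-launched distributed environment, the sizes are illustrative, and the schedule-node construction is left as a placeholder rather than a specific API.

# Hedged sketch: requesting the zero-bubble-V pipeline schedule.
# Assumes colossalai.launch_from_torch() (or equivalent) has already initialized
# the distributed environment; all sizes below are illustrative only.
from colossalai.booster.plugin import HybridParallelPlugin

# Placeholder: build these nodes with ColossalAI's zero-bubble schedule utilities
# (the exact builder API is version-dependent and intentionally not shown here).
scheduler_nodes = ...

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="zbv",                   # zero bubble pipeline
    num_model_chunks=2,               # zbv runs two model chunks per stage (assumed)
    num_microbatches=8,
    scheduler_nodes=scheduler_nodes,  # must not be None, per the assert above
)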
@@ -515,13 +515,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
 
-        # if sp_mode == "all_to_all":
-        #     # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # bsz, q_len, _ = query_states.size()
-        #     # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
-
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -548,7 +541,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
 
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        # cos, sin = self.rotary_emb(value_states, position_ids)
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -607,14 +599,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
-        # return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
 
     return forward
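
For context on the rotary-embedding change above (cos, sin now arrive as position_embeddings instead of a per-layer self.rotary_emb call), here is a hedged, self-contained sketch of how the (cos, sin) pair is computed once and applied. The config values are illustrative, and the config-based LlamaRotaryEmbedding constructor assumes a recent transformers release.

# Hedged sketch: rotary (cos, sin) computed once and passed to attention layers
# as `position_embeddings`, as the replaced forward above now expects.
import torch
from transformers import AutoConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

config = AutoConfig.for_model("llama", hidden_size=64, num_attention_heads=4, num_hidden_layers=1)
rotary_emb = LlamaRotaryEmbedding(config=config)  # config-based init: newer transformers only

bsz, q_len, num_heads, head_dim = 2, 16, 4, 16
hidden_states = torch.randn(bsz, q_len, config.hidden_size)
position_ids = torch.arange(q_len).unsqueeze(0).expand(bsz, -1)

position_embeddings = rotary_emb(hidden_states, position_ids)  # -> (cos, sin)
cos, sin = position_embeddings

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)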
@@ -36,19 +36,10 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            # LlamaFlashAttention2,
             LlamaModel,
-            # LlamaSdpaAttention,
         )
 
-        # ATTN_IMPLEMENTATION = {
-        #     "eager": LlamaAttention,
-        #     "flash_attention_2": LlamaFlashAttention2,
-        #     "sdpa": LlamaSdpaAttention,
-        # }
         policy = {}
-        attn_cls = LlamaAttention
-        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -82,7 +73,7 @@ class LlamaPolicy(Policy):
                 num_kv_heads //= sp_size
                 decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
-            policy[attn_cls] = ModulePolicyDescription(
+            policy[LlamaAttention] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@@ -91,7 +82,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=LlamaAttention,
             )
 
         if self.pipeline_stage_manager is None:
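
The net effect of the three policy hunks above is that the policy no longer selects an attention class per attn_implementation but keys everything on LlamaAttention, which appears consistent with recent transformers releases where the separate flash/SDPA attention classes were folded into a single LlamaAttention. A minimal hedged sketch of the resulting shape follows; the import path for ModulePolicyDescription and the numeric values are assumptions for illustration only.

# Hedged sketch: policy entries keyed directly on LlamaAttention.
from transformers.models.llama.modeling_llama import LlamaAttention
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription  # assumed path

sp_size = 2  # illustrative sequence-parallel degree
decoder_attribute_replacement = {
    "num_key_value_heads": 32 // sp_size,  # illustrative value; mirrors the hunk above
}

policy = {
    LlamaAttention: ModulePolicyDescription(
        attribute_replacement=decoder_attribute_replacement,
    )
}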
@@ -225,7 +225,6 @@ class ModelSharder(object):
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
-            print("held_layers", held_layers)
             set_tensors_to_none(self.model, exclude=set(held_layers))
             return set(self._get_recursive_held_layers(held_layers))
         return None
@@ -16,6 +16,4 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
     for n, p in model.named_parameters(recurse=False):
         setattr(model, n, None)
     for n, buf in model.named_buffers(recurse=False):
-        import torch
-        print("buffer", n, torch.distributed.get_rank())
         setattr(model, n, None)
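
For reference, a small self-contained usage sketch of the utility above; the import path is assumed from the surrounding shardformer code, and the toy model is purely illustrative.

# Hedged sketch: release parameters of modules a pipeline stage does not hold,
# mirroring how ModelSharder calls set_tensors_to_none in the hunk further up.
import torch.nn as nn
from colossalai.shardformer.shard.utils import set_tensors_to_none  # assumed path

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))
held_layers = {model[1]}  # pretend this pipeline stage only holds layer 1

set_tensors_to_none(model, exclude=held_layers)

assert model[0].weight is None      # freed on this stage
assert model[1].weight is not None  # kept, because it is excluded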