Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-03 12:49:42 +00:00)

Commit 2f615a49fd ("fix"), parent d7a9eb0f67.
@@ -1056,8 +1056,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
-        # if sp_size is None or sp_size <= 1:
-        #     enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
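
The kept lines above do two things: they enforce that a zero-bubble ("zbv") pipeline schedule comes with explicit scheduler_nodes, and they default the sequence-parallelism mode to "all_to_all" when none is given. Below is a minimal standalone sketch of that logic; the helper name and its reduced argument list are illustrative only, not part of HybridParallelPlugin.

# Illustrative sketch, not HybridParallelPlugin itself: the validation and
# default resolution kept by this hunk, pulled out of the constructor.
from typing import Optional


def resolve_sequence_parallelism_mode(
    pp_style: str,
    scheduler_nodes,
    enable_sequence_parallelism: bool,
    sequence_parallelism_mode: Optional[str],
) -> Optional[str]:
    # Zero-bubble ("zbv") pipeline schedules need explicit scheduler nodes.
    assert (
        not pp_style == "zbv" or scheduler_nodes is not None
    ), "scheduler_nodes must not be None when using zero bubble pipeline."
    if enable_sequence_parallelism:
        # Fall back to "all_to_all" when no mode was requested.
        return sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
    return None
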
@@ -515,13 +515,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
 
-        # if sp_mode == "all_to_all":
-        #     # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
-        #     # bsz, q_len, _ = query_states.size()
-        #     # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
-
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -548,7 +541,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
 
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        # cos, sin = self.rotary_emb(value_states, position_ids)
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -607,14 +599,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
-        # return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
 
     return forward
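
Two things are worth noting here: the fix settles on the two-value return (attn_output, attn_weights) instead of the commented-out three-value form, and the kept arguments (attn_output, sp_group, scatter_dim=1, gather_dim=2, ...) feed a sequence-parallel all-to-all on the attention output, the same all_to_all_comm helper referenced in the removed comments earlier. The sketch below shows one way such a scatter/gather all-to-all can be written; it assumes equal splits and is an illustration, not ColossalAI's all_to_all_comm.

# Rough sketch of the assumed scatter/gather semantics, not ColossalAI's
# all_to_all_comm: split `x` along scatter_dim across the group, exchange the
# chunks, and concatenate what was received along gather_dim.
import torch
import torch.distributed as dist


def all_to_all_sketch(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
    dist.all_to_all(outputs, inputs, group=group)  # pairwise chunk exchange
    return torch.cat(outputs, dim=gather_dim)
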
@@ -36,19 +36,10 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            # LlamaFlashAttention2,
             LlamaModel,
-            # LlamaSdpaAttention,
         )
 
-        # ATTN_IMPLEMENTATION = {
-        #     "eager": LlamaAttention,
-        #     "flash_attention_2": LlamaFlashAttention2,
-        #     "sdpa": LlamaSdpaAttention,
-        # }
         policy = {}
-        attn_cls = LlamaAttention
-        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -82,7 +73,7 @@ class LlamaPolicy(Policy):
                 num_kv_heads //= sp_size
                 decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
-        policy[attn_cls] = ModulePolicyDescription(
+        policy[LlamaAttention] = ModulePolicyDescription(
             attribute_replacement=decoder_attribute_replacement,
         )
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
@@ -91,7 +82,7 @@ class LlamaPolicy(Policy):
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=LlamaAttention,
             )
 
         if self.pipeline_stage_manager is None:
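
Both policy hunks replace the indirect attn_cls key with LlamaAttention itself: recent transformers releases expose a single LlamaAttention class (eager/SDPA/flash dispatch is handled internally), so the per-implementation lookup table removed above had become dead code. For context, here is a simplified sketch of the method-replacement pattern the policy relies on, assuming a MethodType-style rebinding; the dummy module and factory below are illustrative, not ColossalAI's sharder code.

# Simplified illustration of binding a factory-produced forward (like
# get_llama_flash_attention_forward) onto a target module; not ColossalAI code.
from types import MethodType

import torch.nn as nn


def get_scaled_forward(scale: float):
    # Factory returning an unbound forward, mirroring the pattern in the diff.
    def forward(self, x):
        return self.proj(x) * scale

    return forward


class DummyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


module = DummyAttention()
# "Method replacement": rebind forward on the instance selected by its class key.
module.forward = MethodType(get_scaled_forward(0.5), module)
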
@@ -225,7 +225,6 @@ class ModelSharder(object):
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
-            print("held_layers", held_layers)
             set_tensors_to_none(self.model, exclude=set(held_layers))
             return set(self._get_recursive_held_layers(held_layers))
         return None
@@ -16,6 +16,4 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
     for n, p in model.named_parameters(recurse=False):
         setattr(model, n, None)
     for n, buf in model.named_buffers(recurse=False):
-        import torch
-        print("buffer", n, torch.distributed.get_rank())
         setattr(model, n, None)
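
This last hunk only strips leftover debug printing from set_tensors_to_none. From the signature in the hunk header and the two loops shown, the cleaned-up helper amounts to roughly the sketch below; the `exclude` check and the recursion into child modules are assumptions about surrounding lines that are not part of this hunk.

# Sketch reconstructed from this hunk; the `exclude` check and the recursion
# over children are assumed, since they fall outside the shown context.
from typing import Set

import torch.nn as nn


def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
    if model in exclude:
        return
    for child in model.children():
        set_tensors_to_none(child, exclude=exclude)
    for n, p in model.named_parameters(recurse=False):
        setattr(model, n, None)  # release parameters this pipeline stage does not hold
    for n, buf in model.named_buffers(recurse=False):
        setattr(model, n, None)  # release buffers as well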