https://github.com/hpcaitech/ColossalAI.git

upgrade llama

commit 686982764c
parent 0c5ed65305
@@ -1056,6 +1056,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
+        if sp_size is None or sp_size <= 1:
+            enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
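A minimal, standalone restatement of the new guard: sequence parallelism is now switched off whenever the sequence-parallel group would be trivial. The helper name and the None result for the disabled branch are illustrative only, not taken from the plugin.

def resolve_sp(sp_size, enable_sequence_parallelism, sequence_parallelism_mode):
    # Mirrors the two added lines above; the disabled branch is an assumption.
    if sp_size is None or sp_size <= 1:
        enable_sequence_parallelism = False
    if enable_sequence_parallelism:
        return sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
    return None

print(resolve_sp(1, True, None))  # None -> SP silently disabled
print(resolve_sp(2, True, None))  # 'all_to_all'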
@@ -94,6 +94,7 @@ class LlamaPipelineForwards:
             batch_size, seq_length = input_shape
             device = hidden_states.device
 
+
         # Support SP + PP
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
@@ -112,6 +113,7 @@ class LlamaPipelineForwards:
                 raise ValueError("cache_position is a required argument when using StaticCache.")
             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
+
         seq_length_with_past = seq_length + past_seen_tokens
 
         if output_attentions:
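A standalone illustration of the cache_position bookkeeping in the context above (sizes are hypothetical): positions simply continue from however many tokens the cache has already seen.

import torch

past_seen_tokens, seq_length = 16, 4  # hypothetical cache length and new chunk
device = "cpu"
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
print(cache_position)  # tensor([16, 17, 18, 19])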
@@ -141,7 +143,7 @@ class LlamaPipelineForwards:
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
@@ -177,6 +179,7 @@ class LlamaPipelineForwards:
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
         start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
@@ -204,6 +207,7 @@ class LlamaPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -214,6 +218,7 @@ class LlamaPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings
                 )
             hidden_states = layer_outputs[0]
 
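Taken together, the hunks above compute the rotary (cos, sin) pair once in the model forward and hand it to every decoder layer. A condensed, hedged sketch of the resulting call flow; keyword names beyond those visible in the hunks are assumed, and pipeline bookkeeping plus gradient checkpointing are omitted.

position_embeddings = self.rotary_emb(hidden_states, position_ids)  # shared (cos, sin) tuple

for decoder_layer in self.layers[start_idx:end_idx]:
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=attn_kwargs,
        position_ids=position_ids,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,  # new argument threaded through
    )
    hidden_states = layer_outputs[0]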
@@ -486,8 +491,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -505,30 +510,21 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             )
 
         bsz, q_len, _ = hidden_states.size()
+        input_shape = hidden_states.shape[:-1]
         # sp: modify sp_len when sequence parallel mode is ring
         if is_share_sp_tp(sp_mode):
             q_len *= sp_size
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+        # if sp_mode == "all_to_all":
+        # # query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+        # # key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+        # # value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
+        # # bsz, q_len, _ = query_states.size()
+        # # hidden_states = all_to_all_comm(hidden_states, sp_group, fp8_communication=shard_config.fp8_communication)
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
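With the pretraining_tp weight-slicing branch gone, q/k/v come straight from the linear projections, and input_shape keeps (bsz, q_len) for the output reshape in the last hunk of this function. A standalone illustration with hypothetical sizes and a stand-in linear layer:

import torch
from torch import nn

bsz, q_len, hidden_size = 2, 8, 512  # hypothetical sizes
hidden_states = torch.randn(bsz, q_len, hidden_size)
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)  # stand-in for self.q_proj

input_shape = hidden_states.shape[:-1]   # (bsz, q_len), reused later for the reshape
query_states = q_proj(hidden_states)     # plain projection, no per-slice F.linear calls

attn_output = query_states               # stand-in for the attention result
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
print(attn_output.shape)                 # torch.Size([2, 8, 512])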
@@ -537,9 +533,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
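Inferring the head count with -1 matters once the projections are sharded: under tensor parallelism the local projection width is num_heads // tp_size * head_dim, so hard-coding self.num_heads would no longer match. A standalone illustration with hypothetical sizes:

import torch

num_heads, head_dim, tp_size = 32, 64, 4        # hypothetical model and parallel sizes
bsz, q_len = 2, 8
local_hidden = num_heads // tp_size * head_dim  # per-rank projection output width

query_states = torch.randn(bsz, q_len, local_hidden)
q = query_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)
print(q.shape)  # torch.Size([2, 8, 8, 64]): 8 local heads inferred automatically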
@@ -552,7 +548,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
 
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
+        # cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
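The attention forward now consumes a precomputed (cos, sin) pair instead of calling its own rotary module. A standalone, hedged illustration with random stand-ins for the rotary outputs; it assumes a recent transformers release where apply_rotary_pos_emb takes (q, k, cos, sin):

import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

bsz, heads, seq, head_dim = 2, 4, 8, 64        # hypothetical sizes
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)
cos = torch.randn(bsz, seq, head_dim)          # stand-in for rotary_emb output
sin = torch.randn(bsz, seq, head_dim)

position_embeddings = (cos, sin)               # what the model forward now passes down
cos, sin = position_embeddings
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)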
@@ -610,17 +607,13 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        # return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
     return forward
@@ -36,19 +36,19 @@ class LlamaPolicy(Policy):
         from transformers.models.llama.modeling_llama import (
             LlamaAttention,
             LlamaDecoderLayer,
-            LlamaFlashAttention2,
+            # LlamaFlashAttention2,
             LlamaModel,
-            LlamaSdpaAttention,
+            # LlamaSdpaAttention,
         )
 
-        ATTN_IMPLEMENTATION = {
-            "eager": LlamaAttention,
-            "flash_attention_2": LlamaFlashAttention2,
-            "sdpa": LlamaSdpaAttention,
-        }
+        # ATTN_IMPLEMENTATION = {
+        # "eager": LlamaAttention,
+        # "flash_attention_2": LlamaFlashAttention2,
+        # "sdpa": LlamaSdpaAttention,
+        # }
         policy = {}
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+        attn_cls = LlamaAttention
+        # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
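Recent transformers releases merged the eager/SDPA/flash-attention variants into a single LlamaAttention class, which is presumably why the dispatch table above is commented out in favour of a fixed attn_cls. If the policy ever needed to load against an older transformers again, a hedged, hypothetical guard (not part of this commit) could look like this:

import transformers.models.llama.modeling_llama as llama_modeling
from transformers.models.llama.modeling_llama import LlamaAttention

attn_cls = LlamaAttention
# Older releases still expose the split classes; fall back to them if present.
if hasattr(llama_modeling, "LlamaSdpaAttention"):
    attn_cls = llama_modeling.LlamaSdpaAttention  # hypothetical fallback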
@@ -354,6 +354,7 @@ class LlamaPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
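Every pipeline stage now holds the model-level rotary embedding module in addition to its slice of decoder layers, matching the model forward above, which calls self.rotary_emb on every stage. A hedged sketch of the non-interleaved case; helper and attribute names beyond those visible in the hunk are assumed from the usual ColossalAI policy layout:

held_layers = []
held_layers.append(module.rotary_emb)                                  # shared across all stages
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)   # assumed helper
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_first_stage():
    held_layers.append(module.embed_tokens)                            # assumed first-stage extra
if stage_manager.is_last_stage():
    held_layers.append(module.norm)                                    # assumed last-stage extra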
@@ -225,6 +225,7 @@ class ModelSharder(object):
         """
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
+            print("held_layers", held_layers)
             set_tensors_to_none(self.model, exclude=set(held_layers))
             return set(self._get_recursive_held_layers(held_layers))
         return None
@@ -16,4 +16,6 @@ def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> No
     for n, p in model.named_parameters(recurse=False):
         setattr(model, n, None)
     for n, buf in model.named_buffers(recurse=False):
+        import torch
+
+        print("buffer", n, torch.distributed.get_rank())
         setattr(model, n, None)
@@ -162,9 +162,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # Double Ring Attention
         {
-            "tp_size": 1,
+            "tp_size": 2,
             "pp_size": 1,
-            "sp_size": 4,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
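The double-ring-attention case trades sequence-parallel ranks for tensor-parallel ranks while keeping the same overall device count, presumably so the test still fits the same world size:

old_cfg = {"tp_size": 1, "sp_size": 4}
new_cfg = {"tp_size": 2, "sp_size": 2}
assert old_cfg["tp_size"] * old_cfg["sp_size"] == new_cfg["tp_size"] * new_cfg["sp_size"] == 4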
@@ -226,12 +226,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
         },
         {
-            "tp_size": 2,
+            "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 1,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "ring",
+            "sequence_parallelism_mode": "all_to_all",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
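For context, a hedged sketch of how a config dict like the one above is typically turned into a runnable setup; the keyword names mirror the dict keys, while the distributed launch and the exact test harness are omitted:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes colossalai.launch(...) has already initialised the process group.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    num_microbatches=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    enable_flash_attention=True,
    zero_stage=2,
    initial_scale=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader)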