Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[shardformer] fix gpt2 double head (#4663)

* [shardformer] fix gpt2 test
* fix
* [shardformer] add todo
@@ -78,9 +78,9 @@ class GPT2PipelineForwards:
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
-                input_ids = input_ids.view(-1, seq_length)
+                input_shape = input_ids.size()
+                input_ids = input_ids.view(-1, input_shape[-1])
+                batch_size = input_ids.shape[0]
             elif inputs_embeds is not None:
                 input_shape = inputs_embeds.size()[:-1]
                 batch_size = inputs_embeds.shape[0]
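Why this hunk matters, in brief: double-head style inputs carry an extra choices dimension, so input_ids can be rank 3 rather than rank 2, and the old two-value unpack no longer works. A minimal standalone sketch (illustrative only, not code from the repository; the shapes are assumptions):

import torch

# Assumed GPT2DoubleHeadsModel-style input: (batch_size, num_choices, sequence_length)
input_ids = torch.randint(0, 50257, (2, 3, 8))

# Old path: assumes a rank-2 input_ids, so the unpack raises
# "ValueError: too many values to unpack (expected 2)" for a rank-3 tensor.
try:
    batch_size, seq_length = input_ids.shape
except ValueError as err:
    print("old path fails:", err)

# New path: rank-agnostic, every leading dim is folded into the batch dim.
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])  # (2 * 3, 8)
batch_size = input_ids.shape[0]                  # 6
print(input_ids.shape, batch_size)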
@@ -89,13 +89,14 @@ class GPT2PipelineForwards:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if token_type_ids is not None:
-                token_type_ids = token_type_ids.view(-1, seq_length)
+                token_type_ids = token_type_ids.view(-1, input_shape[-1])
         else:
             if hidden_states is None:
                 raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
             input_shape = hidden_states.size()[:-1]
-            batch_size, seq_length = input_shape[0], input_shape[1]
+            batch_size = input_shape[0]
             device = hidden_states.device
+            hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])

         # GPT2Attention mask.
         if attention_mask is not None:
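For stages other than the first, the activation arrives as hidden_states instead of input_ids, and the added view() applies the same flattening there. A rough sketch, assuming the activation can likewise carry an extra choices dimension:

import torch

# Assumed shape: (batch_size, num_choices, sequence_length, hidden_size)
hidden_states = torch.randn(2, 3, 8, 768)

input_shape = hidden_states.size()[:-1]  # (2, 3, 8)
batch_size = input_shape[0]              # 2

# Collapse all leading dims so downstream blocks always see (N, seq_len, hidden_size);
# a rank-3 activation passes through unchanged.
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
print(hidden_states.shape)               # torch.Size([6, 8, 768])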
@@ -136,9 +137,9 @@ class GPT2PipelineForwards:
         if stage_manager.is_first_stage():
             if position_ids is not None:
-                position_ids = position_ids.view(-1, seq_length)
+                position_ids = position_ids.view(-1, input_shape[-1])
             else:
-                position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+                position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
                 position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

             if inputs_embeds is None:
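The position_ids construction follows the same pattern, keyed off input_shape[-1] rather than a separate seq_length variable. A small sketch (the shapes and device are assumptions for illustration):

import torch

input_shape = torch.Size([2, 3, 8])  # (batch_size, num_choices, sequence_length)
device = torch.device("cpu")

position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
print(position_ids.shape)  # torch.Size([1, 8])

# The (1, seq_len) index tensor yields position embeddings of shape
# (1, seq_len, hidden_size), which broadcast against the flattened
# (N, seq_len, hidden_size) token embeddings, as in the upstream
# transformers GPT-2 forward that this code mirrors.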
@@ -721,7 +722,6 @@ def get_gpt2_flash_attention_forward():
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        _, tgt_len, _ = hidden_states.size()

         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
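The last hunk only drops a local assignment; the value it computed is trivially recoverable from hidden_states if it is ever needed again, so removing it is presumably a dead-code cleanup. A tiny sketch of that equivalence (shapes are assumptions):

import torch

hidden_states = torch.randn(6, 8, 768)  # (N, tgt_len, hidden_size)

# What the removed line unpacked ...
_, tgt_len, _ = hidden_states.size()
# ... is just the middle dimension, available anywhere hidden_states is in scope.
assert tgt_len == hidden_states.size(1) == 8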