[shardformer]fix gpt2 double head (#4663)

* [shardformer]fix gpt2 test

* fix

* [shardformer] add todo
Author: flybird11111
Date: 2023-09-11 18:35:03 +08:00 (committed by GitHub)
Commit: eedaa3e1ef
Parent: 554aa9592e
5 changed files with 38 additions and 29 deletions


@@ -78,9 +78,9 @@ class GPT2PipelineForwards:
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
-                input_ids = input_ids.view(-1, seq_length)
+                input_shape = input_ids.size()
+                input_ids = input_ids.view(-1, input_shape[-1])
+                batch_size = input_ids.shape[0]
             elif inputs_embeds is not None:
                 input_shape = inputs_embeds.size()[:-1]
                 batch_size = inputs_embeds.shape[0]
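
The hunk above is the core of the double-head fix: GPT2DoubleHeadsModel passes input_ids with an extra num_choices dimension, so unpacking `batch_size, seq_length = input_ids.shape` fails, whereas flattening against `input_shape[-1]` handles both layouts. A minimal sketch, with illustrative shapes that are not taken from the patch:

```python
import torch

# Illustrative shapes only: 2-D input for GPT2Model,
# 3-D input for GPT2DoubleHeadsModel (batch, num_choices, seq_len).
input_ids_2d = torch.randint(0, 50257, (4, 128))
input_ids_3d = torch.randint(0, 50257, (4, 2, 128))

for input_ids in (input_ids_2d, input_ids_3d):
    # The old `batch_size, seq_length = input_ids.shape` would raise
    # "too many values to unpack" on the 3-D tensor.
    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])  # fold extra dims into the batch dim
    batch_size = input_ids.shape[0]
    print(tuple(input_shape), "->", tuple(input_ids.shape), "batch_size =", batch_size)
# (4, 128) -> (4, 128) batch_size = 4
# (4, 2, 128) -> (8, 128) batch_size = 8
```
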
@@ -89,13 +89,14 @@ class GPT2PipelineForwards:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if token_type_ids is not None:
-                token_type_ids = token_type_ids.view(-1, seq_length)
+                token_type_ids = token_type_ids.view(-1, input_shape[-1])
         else:
             if hidden_states is None:
                 raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
             input_shape = hidden_states.size()[:-1]
-            batch_size, seq_length = input_shape[0], input_shape[1]
+            batch_size = input_shape[0]
             device = hidden_states.device
+            hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])

         # GPT2Attention mask.
         if attention_mask is not None:
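
For stages other than the first, the added `hidden_states.view((-1,) + hidden_states.shape[-2:])` applies the same folding to the incoming activations, keeping only the last two dimensions and merging any leading ones into the batch dimension. A small sketch under assumed shapes:

```python
import torch

# Assumed shape: (batch, num_choices, seq_len, hidden) arriving at a later
# pipeline stage; the view keeps (seq_len, hidden) and folds the rest.
hidden_states = torch.randn(4, 2, 128, 768)
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
print(tuple(hidden_states.shape))  # (8, 128, 768)

# A tensor that is already flat passes through unchanged.
flat = torch.randn(8, 128, 768)
print(tuple(flat.view((-1,) + flat.shape[-2:]).shape))  # (8, 128, 768)
```
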
@@ -136,9 +137,9 @@ class GPT2PipelineForwards:
         if stage_manager.is_first_stage():
             if position_ids is not None:
-                position_ids = position_ids.view(-1, seq_length)
+                position_ids = position_ids.view(-1, input_shape[-1])
             else:
-                position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+                position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
                 position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

         if inputs_embeds is None:
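
Position ids follow the same convention: they are derived from `input_shape[-1]` (the sequence length), so the code works for both 2-D and flattened 3-D inputs. A brief sketch with illustrative values:

```python
import torch

# Illustrative: input_shape as it would look for a double-head batch.
input_shape = torch.Size([4, 2, 128])
device = "cpu"

position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
print(tuple(position_ids.shape))  # (1, 128), broadcast over the folded batch
```
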
@@ -721,7 +722,6 @@ def get_gpt2_flash_attention_forward():
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        _, tgt_len, _ = hidden_states.size()
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):