[shardformer/sequence parallel] warn that sequence parallelism is not supported for OPT, and fix a bug in GPT-2 pipeline parallelism (#4488)

This commit is contained in:
Bin Jia
2023-08-22 17:35:35 +08:00
committed by GitHub
parent 5545114fd8
commit 351351a36e
2 changed files with 5 additions and 1 deletions

View File

@@ -148,7 +148,7 @@ class GPT2PipelineForwards:
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)