mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-12 13:21:47 +00:00
[shardformer] fix GPT2DoubleHeadsModel (#4703)
This commit is contained in:
parent
068372a738
commit
c7d6975d29
@ -94,9 +94,9 @@ class GPT2PipelineForwards:
|
|||||||
if hidden_states is None:
|
if hidden_states is None:
|
||||||
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
||||||
input_shape = hidden_states.size()[:-1]
|
input_shape = hidden_states.size()[:-1]
|
||||||
batch_size = input_shape[0]
|
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
|
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
# GPT2Attention mask.
|
# GPT2Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user