[hotfix] fix torch 2.0 compatibility (#4936)

* [hotfix] fix launch

* [test] fix test gemini optim

* [shardformer] fix vit
This commit is contained in:
Hongxin Liu
2023-10-18 11:05:25 +08:00
committed by GitHub
parent 21ba89cab6
commit 1f5d2e8062
6 changed files with 39 additions and 55 deletions

View File

@@ -100,35 +100,24 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
hidden_states = embedding_output
else:
assert (
hidden_states is not None
), f"Current stage is {stage_manager.stage}, hidden_states should not be None"
# Go through encoder
encoder_outputs = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=hidden_states,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
if not stage_manager.is_last_stage():
hidden_states = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=embedding_output,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
return {"hidden_states": hidden_states}
else:
encoder_outputs = _encoder_forward(
encoder=self.encoder,
start_idx=stage_index[0],
end_idx=stage_index[1],
hidden_states=hidden_states,
head_mask=head_mask,
return_dict=return_dict,
stage_manager=stage_manager,
)
return {"hidden_states": encoder_outputs}
# Go through rest layers
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None