mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[hotfix] fix torch 2.0 compatibility (#4936)
* [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit
This commit is contained in:
@@ -100,35 +100,24 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
|
||||
embedding_output = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
hidden_states = embedding_output
|
||||
else:
|
||||
assert (
|
||||
hidden_states is not None
|
||||
), f"Current stage is {stage_manager.stage}, hidden_states should not be None"
|
||||
|
||||
# Go through encoder
|
||||
encoder_outputs = _encoder_forward(
|
||||
encoder=self.encoder,
|
||||
start_idx=stage_index[0],
|
||||
end_idx=stage_index[1],
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
)
|
||||
if not stage_manager.is_last_stage():
|
||||
hidden_states = _encoder_forward(
|
||||
encoder=self.encoder,
|
||||
start_idx=stage_index[0],
|
||||
end_idx=stage_index[1],
|
||||
hidden_states=embedding_output,
|
||||
head_mask=head_mask,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
)
|
||||
return {"hidden_states": hidden_states}
|
||||
else:
|
||||
encoder_outputs = _encoder_forward(
|
||||
encoder=self.encoder,
|
||||
start_idx=stage_index[0],
|
||||
end_idx=stage_index[1],
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
)
|
||||
return {"hidden_states": encoder_outputs}
|
||||
|
||||
# Go through rest layers
|
||||
sequence_output = encoder_outputs[0]
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
Reference in New Issue
Block a user