mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[example] update vit example for hybrid parallel plugin (#4641)
* update vit example for hybrid plugin
* reset tp/pp size
* fix dataloader iteration bug
* update optimizer passing in evaluation/add grad_accum
* change criterion
* wrap tqdm
* change grad_accum to grad_checkpoint
* fix pbar
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
@@ -72,18 +72,17 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
|
||||
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions is not None:
|
||||
logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
|
||||
output_attentions = None
|
||||
if output_hidden_states is not None:
|
||||
logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = None
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
Reference in New Issue
Block a user