[example] update vit example for hybrid parallel plugin (#4641)

* update vit example for hybrid plugin

* reset tp/pp size

* fix dataloader iteration bug

* update optimizer passing in evaluation/add grad_accum

* change criterion

* wrap tqdm

* change grad_accum to grad_checkpoint

* fix pbar
This commit is contained in:
Author: Baizhou Zhang
Date: 2023-09-07 17:38:45 +08:00
Committed by: GitHub
Parent: 660eed9124
Commit: 295b38fecf
10 changed files with 246 additions and 192 deletions
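For readers following the example changes summarized in the commit message, below is a minimal, hedged sketch of what a ViT fine-tuning step on HybridParallelPlugin typically looks like. The checkpoint name, synthetic dataset, tp/pp sizes, batch size and hyperparameters are illustrative assumptions, not values taken from this commit; the authoritative script is the ColossalAI ViT example this commit updates.

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import ViTForImageClassification

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator


class RandomImageDataset(Dataset):
    """Synthetic stand-in for a real image dataset; yields ViT-style inputs."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"pixel_values": torch.randn(3, 224, 224), "labels": torch.tensor(idx % 2)}


# Launch with torchrun; the config dict is required by ColossalAI versions of this era.
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
model.gradient_checkpointing_enable()  # "change grad_accum to grad_checkpoint"

# tp/pp sizes are placeholders ("reset tp/pp size"); fp16 and 4 microbatches chosen arbitrarily.
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision="fp16")
booster = Booster(plugin=plugin)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


def criterion(outputs, inputs):
    # "change criterion": rely on the loss the HF model computes from `labels`.
    return outputs.loss


train_dataloader = plugin.prepare_dataloader(RandomImageDataset(), batch_size=16, shuffle=True, drop_last=True)

model, optimizer, criterion, train_dataloader, _ = booster.boost(
    model, optimizer, criterion=criterion, dataloader=train_dataloader
)

# "wrap tqdm" / "fix pbar": draw the bar on a single rank; with pipeline parallelism the
# example shows it on the last stage, where the loss is produced.
pbar = tqdm(range(len(train_dataloader)), disable=not coordinator.is_master())
train_iter = iter(train_dataloader)  # "fix dataloader iteration bug": pass an explicit iterator
for _ in pbar:
    # execute_pipeline runs forward and backward over the micro-batches of one global batch.
    out = booster.execute_pipeline(train_iter, model, criterion, optimizer, return_loss=True)
    optimizer.step()
    optimizer.zero_grad()

The evaluation path ("update optimizer passing in evaluation") uses the same booster-prepared objects; it is omitted from this sketch.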


@@ -1,9 +1,9 @@
-import logging
 import math
 from typing import Dict, List, Optional, Set, Tuple, Union

 import torch

 from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
+from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -72,18 +72,17 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (output_hidden_states
                                 if output_hidden_states is not None else self.config.output_hidden_states)
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if output_attentions is not None:
-            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
-            output_attentions = None
-        if output_hidden_states is not None:
-            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
-            output_hidden_states = None
+        logger = logging.get_logger(__name__)
+
+        # Preprocess passed in arguments
+        if output_attentions:
+            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            output_hidden_states = False

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
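As context for the hunk above: the old code called Python's standard-library logging module directly, while the new code obtains a logger from transformers' logging utility and uses its warning_once helper, so the unsupported-argument warning is emitted once per process rather than on every pipeline forward pass. A standalone sketch of that pattern (forward_stub is an illustrative name, not part of the diff):

from transformers.utils import logging

logger = logging.get_logger(__name__)


def forward_stub(output_attentions=None):
    # warning_once caches the message, so repeated identical warnings are logged only once.
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    return output_attentions


for _ in range(3):
    forward_stub(output_attentions=True)  # the warning appears on the first call only

The cache is keyed on the message text, so distinct messages are still each logged once.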