[example] update vit example for hybrid parallel plugin (#4641)

* update vit example for hybrid plugin

* reset tp/pp size

* fix dataloader iteration bug

* update optimizer passing in evaluation/add grad_accum

* change criterion

* wrap tqdm

* change grad_accum to grad_checkpoint

* fix pbar
This commit is contained in:
Baizhou Zhang
2023-09-07 17:38:45 +08:00
committed by GitHub
parent 660eed9124
commit 295b38fecf
10 changed files with 246 additions and 192 deletions

View File

@@ -884,6 +884,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if self.gradient_checkpointing and self.training:
if use_cache:
logger = logging.get_logger(__name__)
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False