[refactor] pipeline, put runtime schedule into engine. (#627)

Author: YuliangLiu0306
Date: 2022-04-03 20:46:45 +08:00
Committed by: GitHub
Parent: e5d615aeee
Commit: ade05a5d83

9 changed files with 68 additions and 49 deletions
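In short: the `schedule` argument moves off `Trainer`, and the engine now owns the runtime schedule behind a new `engine.execute_schedule(...)` entry point. A minimal before/after sketch of the user-facing change, assuming `engine`, `schedule`, `timer`, `logger`, and `data_iter` come from the usual setup (that setup is not part of this diff):

    # before this commit: the schedule was a Trainer concern
    trainer = Trainer(engine=engine, schedule=schedule, timer=timer, logger=logger)

    # after this commit: the engine drives the schedule itself
    trainer = Trainer(engine=engine, timer=timer, logger=logger)
    logits, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)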


@@ -9,7 +9,6 @@ from tqdm import tqdm
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
-from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
 from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
@@ -23,13 +22,9 @@ class Trainer:
     Args:
         engine (:class:`Engine`): Engine responsible for the process function.
-        schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
         timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
         logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
-    Note:
-        when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
-        you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -42,7 +37,7 @@ class Trainer:
         >>> # Beginning training progress
         >>> timier = ...
         >>> logger = ...
-        >>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timier)
+        >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
         >>> # add hooks you would like to use here.
         >>> hook_list = []
         >>> trainer.fit(
@@ -61,7 +56,6 @@ class Trainer:
def __init__(
self,
engine: Engine,
schedule: BaseSchedule = None,
timer: MultiTimer = None,
logger: DistributedLogger = None,
):
@@ -86,17 +80,6 @@ class Trainer:
         # multi-timer for time benchmarking
         self._timer = timer
-        # set schedule which specifies the training iteration for the engine
-        if schedule is None:
-            schedule = NonPipelineSchedule()
-        if (gpc.is_initialized(ParallelMode.PIPELINE)
-                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
-            assert not isinstance(
-                schedule, NonPipelineSchedule
-            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
-        self._schedule = schedule
-        self._schedule.pre_processing(engine)

     @property
     def cur_epoch(self):
         """Returns the index of the current epoch."""
@@ -129,10 +112,6 @@ class Trainer:
def engine(self):
return self._engine
@property
def schedule(self):
return self._schedule
def _set_current_step(self, epoch: int):
"""Sets current step number.
@@ -203,8 +182,7 @@ class Trainer:
             # run 1 training step
             self.engine.zero_grad()
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=False,
                 return_loss=True,
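With this change, a training iteration outside the Trainer goes through the engine directly. A sketch of one manual training step mirroring the call above, assuming the usual `Engine` helpers (`train()`, `zero_grad()`, `step()`) and a `train_dataloader` from the standard setup:

    engine.train()
    data_iter = iter(train_dataloader)
    engine.zero_grad()
    # forward_only=False runs both forward and backward under the engine's schedule
    logits, label, loss = engine.execute_schedule(data_iter, forward_only=False, return_loss=True)
    engine.step()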
@@ -260,8 +238,7 @@ class Trainer:
for _ in progress:
self._call_hooks("before_test_iter")
self._call_timer(action="start", item="Test-step")
logits, label, loss = self.schedule.forward_backward_step(
self.engine,
logits, label, loss = self.engine.execute_schedule(
data_iter,
forward_only=True,
return_loss=True,
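Evaluation is the same entry point with `forward_only=True`, so no backward pass runs but the loss can still be reported. A sketch, assuming a `test_dataloader` from the usual setup:

    engine.eval()
    data_iter = iter(test_dataloader)
    logits, label, loss = engine.execute_schedule(data_iter, forward_only=True, return_loss=True)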
@@ -449,8 +426,7 @@ class Trainer:
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.schedule.forward_backward_step(self.engine,
-                                                           data_iter,
-                                                           forward_only=True,
-                                                           return_loss=False)
+        output, _, _ = self.engine.execute_schedule(data_iter,
+                                                    forward_only=True,
+                                                    return_loss=False)
         return output
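The predict path keeps its single-batch trick: schedules consume (data, label) pairs from an iterator, so one batch is wrapped as a one-element dataloader with a `None` label and run forward-only. The same pattern from user code, as a sketch with `data` being a prepared input batch:

    simple_dataloader = [(data, None)]   # one (data, label) pair; the label is unused
    data_iter = iter(simple_dataloader)
    output, _, _ = engine.execute_schedule(data_iter, forward_only=True, return_loss=False)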