Mirror of https://github.com/hpcaitech/ColossalAI.git

[refactor] pipeline, put runtime schedule into engine. (#627)
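In short: the Trainer no longer owns a runtime schedule. Every call site that used to drive the schedule directly now goes through the engine. A before/after sketch distilled from the hunks below:

    # before #627: the trainer invoked the schedule itself
    logits, label, loss = self.schedule.forward_backward_step(
        self.engine, data_iter, forward_only=False, return_loss=True)

    # after #627: the engine owns the schedule and exposes one entry point
    logits, label, loss = self.engine.execute_schedule(
        data_iter, forward_only=False, return_loss=True)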
@@ -9,7 +9,6 @@ from tqdm import tqdm
 
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
-from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer
 from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
@@ -23,13 +22,9 @@ class Trainer:
 
     Args:
         engine (:class:`Engine`): Engine responsible for the process function.
-        schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
         timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
         logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
 
-    Note:
-        when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
-        you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`
 
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -42,7 +37,7 @@ class Trainer:
         >>> # Beginning training progress
         >>> timier = ...
         >>> logger = ...
-        >>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timier)
+        >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
         >>> # add hooks you would like to use here.
         >>> hook_list = []
         >>> trainer.fit(
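With the `schedule` argument gone, building a trainer needs only the engine plus the optional timer and logger. A minimal sketch of the new construction, assuming the usual `colossalai.initialize` entry point (not part of this diff) returns a ready engine, and using `get_dist_logger` as the logger factory:

    import colossalai
    from colossalai.logging import get_dist_logger
    from colossalai.trainer import Trainer
    from colossalai.utils import MultiTimer

    # model, optimizer, criterion, train_dataloader are defined elsewhere;
    # initialize() wrapping them into an Engine is assumed, not shown in this diff
    engine, train_dataloader, _, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader)

    # no schedule argument anymore -- the engine carries it
    trainer = Trainer(engine=engine, timer=MultiTimer(), logger=get_dist_logger())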
@@ -61,7 +56,6 @@ class Trainer:
     def __init__(
         self,
         engine: Engine,
-        schedule: BaseSchedule = None,
         timer: MultiTimer = None,
         logger: DistributedLogger = None,
     ):
@@ -86,17 +80,6 @@ class Trainer:
         # multi-timer for time benchmarking
         self._timer = timer
 
-        # set schedule which specifies the training iteration for the engine
-        if schedule is None:
-            schedule = NonPipelineSchedule()
-        if (gpc.is_initialized(ParallelMode.PIPELINE)
-                and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
-            assert not isinstance(
-                schedule, NonPipelineSchedule
-            ), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
-        self._schedule = schedule
-        self._schedule.pre_processing(engine)
-
     @property
     def cur_epoch(self):
         """Returns the index of the current epoch."""
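The block deleted above is the schedule-selection and pre-processing logic that, per the commit title, moves into the engine. The engine side is not shown in this diff, so the following is only an illustrative sketch of where that logic presumably lands (the class body and attribute wiring are hypothetical):

    from colossalai.engine.schedule import NonPipelineSchedule

    class Engine:
        # illustrative sketch only; the real engine-side change is outside this diff
        def __init__(self, schedule=None):
            # mirrors the default the Trainer used to apply
            if schedule is None:
                schedule = NonPipelineSchedule()
            self._schedule = schedule
            self._schedule.pre_processing(self)

        def execute_schedule(self, data_iter, **kwargs):
            # the single entry point the trainer now calls
            return self._schedule.forward_backward_step(self, data_iter, **kwargs)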
@@ -129,10 +112,6 @@ class Trainer:
     def engine(self):
         return self._engine
 
-    @property
-    def schedule(self):
-        return self._schedule
-
     def _set_current_step(self, epoch: int):
         """Sets current step number.
 
@@ -203,8 +182,7 @@ class Trainer:
 
             # run 1 training step
             self.engine.zero_grad()
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=False,
                 return_loss=True,
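The shape of a training step is unchanged; only the receiver differs. One step under the new API, with `engine.step()` assumed to apply the optimizer update as in the existing Engine interface:

    # one training step as the trainer now performs it
    engine.zero_grad()
    logits, label, loss = engine.execute_schedule(
        data_iter,
        forward_only=False,  # run the backward pass too
        return_loss=True,    # loss is needed for logging/hooks
    )
    engine.step()  # assumed: optimizer update via the engine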
@@ -260,8 +238,7 @@ class Trainer:
         for _ in progress:
             self._call_hooks("before_test_iter")
             self._call_timer(action="start", item="Test-step")
-            logits, label, loss = self.schedule.forward_backward_step(
-                self.engine,
+            logits, label, loss = self.engine.execute_schedule(
                 data_iter,
                 forward_only=True,
                 return_loss=True,
@@ -449,8 +426,7 @@ class Trainer:
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
         data_iter = iter(simple_dataloader)
-        output, _, _ = self.schedule.forward_backward_step(self.engine,
-                                                           data_iter,
-                                                           forward_only=True,
-                                                           return_loss=False)
+        output, _, _ = self.engine.execute_schedule(data_iter,
+                                                    forward_only=True,
+                                                    return_loss=False)
         return output
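For single-batch inference the trainer wraps the batch in a one-element dataloader so the schedule API stays uniform. The same pattern, sketched outside the trainer:

    # wrap one batch so it looks like a (data, label) dataloader
    simple_dataloader = [(data, None)]
    data_iter = iter(simple_dataloader)
    output, _, _ = engine.execute_schedule(data_iter,
                                           forward_only=True,   # no backward pass
                                           return_loss=False)   # label is None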