[refactor] pipeline, put runtime schedule into engine. (#627)

YuliangLiu0306 authored on 2022-04-03 20:46:45 +08:00 (committed by GitHub)
parent e5d615aeee
commit ade05a5d83
9 changed files with 68 additions and 49 deletions


@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 from asyncio.log import logger
-from typing import List
+from typing import List, Iterable
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
+from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
 from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
 from colossalai.logging import get_dist_logger
@@ -27,6 +28,7 @@ class Engine:
         clip_grad_norm (float, optional): The norm of gradient clipping.
         ophook_list (list): List of ophook.
         verbose (bool): whether to display log info.
+        schedule (``BaseSchedule``): Runtime schedule.
 
     Examples:
         >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -59,7 +61,8 @@ class Engine:
                  gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
                  ophook_list: Optional[List[BaseOpHook]] = None,
-                 verbose: bool = True):
+                 verbose: bool = True,
+                 schedule: Optional[BaseSchedule] = None):
         self._model = model
         self._optimizer = optimizer
         self._criterion = criterion
@@ -80,6 +83,14 @@ class Engine:
             self._ophook_list = []
         else:
             self._ophook_list = ophook_list
+
+        # build schedule
+        if schedule:
+            self._schedule = schedule
+        else:
+            self._schedule = NonPipelineSchedule()
+        if self.uses_pipeline:
+            self._schedule.pre_processing(self)
         register_ophooks_recursively(self._model, self._ophook_list)
 
     @property
@@ -102,6 +113,16 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion
 
+    @property
+    def schedule(self):
+        """Schedule attached to the engine"""
+        return self._schedule
+
+    @property
+    def uses_pipeline(self):
+        """show the pipeline parallel used or not"""
+        return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))
+
     def add_hook(self, ophook: Type[BaseOpHook]) -> None:
         """add necessary hook"""
         # whether this hook exist
@@ -165,6 +186,16 @@ class Engine:
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()
 
+    def execute_schedule(self, data_iter: Iterable, **kwargs):
+        """Run the forward, loss computation, and backward for the model.
+        Returns a tuple of (output, label, loss).
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
+        """
+        output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
+        return output, label, loss
+
     def train(self):
         """Sets the model to training mode.
@@ -176,4 +207,4 @@ class Engine:
         """Sets the model to evaluation mode.
         """
         self.training = False
-        self._model.eval()
\ No newline at end of file
+        self._model.eval()
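
The Engine half of this refactor gives callers a single entry point: construct the Engine with an optional schedule, then drive a step through execute_schedule. A minimal usage sketch under assumed setup — model, optimizer, criterion, and train_dataloader are placeholders, and the PipelineSchedule argument value is illustrative, not taken from this commit:

# Sketch: wiring the new schedule argument through Engine.
# model/optimizer/criterion/train_dataloader are assumed to exist already.
from colossalai.engine import Engine
from colossalai.engine.schedule import PipelineSchedule

schedule = PipelineSchedule(num_microbatches=4)  # illustrative value
engine = Engine(model=model,
                optimizer=optimizer,
                criterion=criterion,
                schedule=schedule)  # omit schedule to fall back to NonPipelineSchedule

engine.train()
data_iter = iter(train_dataloader)
# execute_schedule forwards to self._schedule.forward_backward_step(...)
output, label, loss = engine.execute_schedule(data_iter, return_loss=True)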


@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
import torch
from typing import Iterable, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
return self._move_to_device(data), self._move_to_device(label)
return data, label
def pre_processing(self, engine: Engine):
def pre_processing(self, engine):
"""To perform actions before running the schedule.
"""
pass
@abstractmethod
def forward_backward_step(self,
engine: Engine,
engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,


@@ -5,7 +5,6 @@ from typing import Iterable
 
 import torch
 
-from colossalai.engine import Engine
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context
 
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
     """
 
     def forward_backward_step(self,
-                              engine: Engine,
+                              engine,
                               data_iter: Iterable,
                               forward_only: bool = False,
                               return_loss: bool = True,
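
A side effect visible in the last two files: the engine: Engine annotations on pre_processing and forward_backward_step are removed rather than relocated. Once the engine module imports the schedule classes at module load, the schedule modules can no longer import Engine back without creating an import cycle. A hedged sketch of one conventional way to keep the annotation without the runtime import — an alternative, not what this commit does:

# Sketch only: typing.TYPE_CHECKING keeps the annotation for static checkers
# while skipping the import at runtime, which breaks the engine <-> schedule cycle.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from colossalai.engine import Engine  # never imported at runtime

class BaseSchedule:
    def pre_processing(self, engine: "Engine"):
        """Actions to perform before the schedule runs."""
        pass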