mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[refactor] pipeline, put runtime schedule into engine. (#627)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from asyncio.log import logger
|
||||
from typing import List
|
||||
from typing import List, Iterable
|
||||
from torch.nn import Module
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.optim import Optimizer
|
||||
@@ -10,6 +10,7 @@ from torch.optim import Optimizer
|
||||
from colossalai.logging import get_dist_logger
|
||||
from torch import Tensor
|
||||
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
|
||||
from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
|
||||
from typing import Optional, Type
|
||||
from colossalai.engine.gradient_handler import BaseGradientHandler
|
||||
from colossalai.logging import get_dist_logger
|
||||
@@ -27,6 +28,7 @@ class Engine:
|
||||
clip_grad_norm (float, optional): The norm of gradient clipping.
|
||||
ophook_list (list): List of ophook.
|
||||
verbose (bool): whether to display log info.
|
||||
schedule (''BaseSchedule''): Runtime schedule.
|
||||
|
||||
Examples:
|
||||
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
|
||||
@@ -59,7 +61,8 @@ class Engine:
|
||||
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
|
||||
clip_grad_norm: float = 0.0,
|
||||
ophook_list: Optional[List[BaseOpHook]] = None,
|
||||
verbose: bool = True):
|
||||
verbose: bool = True,
|
||||
schedule: Optional[BaseSchedule] = None):
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._criterion = criterion
|
||||
@@ -80,6 +83,14 @@ class Engine:
|
||||
self._ophook_list = []
|
||||
else:
|
||||
self._ophook_list = ophook_list
|
||||
|
||||
# build schedule
|
||||
if schedule:
|
||||
self._schedule = schedule
|
||||
else:
|
||||
self._schedule = NonPipelineSchedule()
|
||||
if self.uses_pipeline:
|
||||
self._schedule.pre_processing(self)
|
||||
register_ophooks_recursively(self._model, self._ophook_list)
|
||||
|
||||
@property
|
||||
@@ -102,6 +113,16 @@ class Engine:
|
||||
"""Criterion attached to the engine"""
|
||||
return self._criterion
|
||||
|
||||
@property
|
||||
def schedule(self):
|
||||
"""Schedule attached to the engine"""
|
||||
return self._schedule
|
||||
|
||||
@property
|
||||
def uses_pipeline(self):
|
||||
"""show the pipeline parallel used or not"""
|
||||
return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))
|
||||
|
||||
def add_hook(self, ophook: Type[BaseOpHook]) -> None:
|
||||
"""add necessary hook"""
|
||||
# whether this hook exist
|
||||
@@ -165,6 +186,16 @@ class Engine:
|
||||
"""
|
||||
for handler in self._gradient_handlers:
|
||||
handler.handle_gradient()
|
||||
|
||||
def execute_schedule(self, data_iter: Iterable, **kwargs):
|
||||
"""Run the forward, loss computation, and backward for the model.
|
||||
Returns a tuple of (output, label, loss).
|
||||
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
|
||||
"""
|
||||
output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
|
||||
return output, label, loss
|
||||
|
||||
def train(self):
|
||||
"""Sets the model to training mode.
|
||||
@@ -176,4 +207,4 @@ class Engine:
|
||||
"""Sets the model to evaluation mode.
|
||||
"""
|
||||
self.training = False
|
||||
self._model.eval()
|
||||
self._model.eval()
|
||||
|
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
from typing import Iterable, Callable
|
||||
from .._base_engine import Engine
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
|
||||
return self._move_to_device(data), self._move_to_device(label)
|
||||
return data, label
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
def pre_processing(self, engine):
|
||||
"""To perform actions before running the schedule.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(self,
|
||||
engine: Engine,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
|
@@ -5,7 +5,6 @@ from typing import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.engine import Engine
|
||||
from ._base_schedule import BaseSchedule
|
||||
from colossalai.utils import conditional_context
|
||||
|
||||
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
|
||||
"""
|
||||
|
||||
def forward_backward_step(self,
|
||||
engine: Engine,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
|
Reference in New Issue
Block a user