[refactor] pipeline, put runtime schedule into engine. (#627)

This commit is contained in:
YuliangLiu0306
2022-04-03 20:46:45 +08:00
committed by GitHub
parent e5d615aeee
commit ade05a5d83
9 changed files with 68 additions and 49 deletions

View File

@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
import torch
from typing import Iterable, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
return self._move_to_device(data), self._move_to_device(label)
return data, label
def pre_processing(self, engine: Engine):
def pre_processing(self, engine):
"""To perform actions before running the schedule.
"""
pass
@abstractmethod
def forward_backward_step(self,
engine: Engine,
engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,

View File

@@ -5,7 +5,6 @@ from typing import Iterable
import torch
from colossalai.engine import Engine
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
"""
def forward_backward_step(self,
engine: Engine,
engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True,