mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[refactor] pipeline, put runtime schedule into engine. (#627)
This commit is contained in:
@@ -6,7 +6,6 @@ from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
from typing import Iterable, Callable
|
||||
from .._base_engine import Engine
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
@@ -75,14 +74,14 @@ class BaseSchedule(ABC):
|
||||
return self._move_to_device(data), self._move_to_device(label)
|
||||
return data, label
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
def pre_processing(self, engine):
|
||||
"""To perform actions before running the schedule.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(self,
|
||||
engine: Engine,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
|
@@ -5,7 +5,6 @@ from typing import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.engine import Engine
|
||||
from ._base_schedule import BaseSchedule
|
||||
from colossalai.utils import conditional_context
|
||||
|
||||
@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
|
||||
"""
|
||||
|
||||
def forward_backward_step(self,
|
||||
engine: Engine,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
|
Reference in New Issue
Block a user