mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests * fixed 2.5d runtime issue * reworked split batch, now called in trainer.schedule.load_batch Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
|
||||
from .._base_engine import Engine
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from colossalai.nn.layer import split_batch
|
||||
|
||||
class BaseSchedule(ABC):
|
||||
"""A basic helper class to control the process of training or evaluation.
|
||||
@@ -59,7 +59,11 @@ class BaseSchedule(ABC):
|
||||
else:
|
||||
data, label = batch_data
|
||||
|
||||
data, label = self._to_list(data), self._to_list(label)
|
||||
if isinstance(label, (tuple, list)):
|
||||
self.batch_size = label[0].size(0)
|
||||
else:
|
||||
self.batch_size = label.size(0)
|
||||
data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
|
||||
return self._move_to_device(data), self._move_to_device(label)
|
||||
|
||||
def pre_processing(self, engine: Engine):
|
||||
|
Reference in New Issue
Block a user