Hotfix/Colossalai layers (#92)

* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
アマデウス
2021-12-29 23:32:10 +08:00
committed by GitHub
parent 0fedef4f3c
commit 01a80cd86d
71 changed files with 1033 additions and 773 deletions

View File

@@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.nn.layer import split_batch
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
@@ -59,7 +59,11 @@ class BaseSchedule(ABC):
else:
data, label = batch_data
data, label = self._to_list(data), self._to_list(label)
if isinstance(label, (tuple, list)):
self.batch_size = label[0].size(0)
else:
self.batch_size = label.size(0)
data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
return self._move_to_device(data), self._move_to_device(label)
def pre_processing(self, engine: Engine):