Added MoE parallel (#127)

This commit is contained in:
HELSON
2022-01-07 15:08:36 +08:00
committed by GitHub
parent 42741dd4a3
commit dceae85195
26 changed files with 858 additions and 18 deletions

View File

@@ -38,8 +38,9 @@ class BaseSchedule(ABC):
return data
@staticmethod
def _check_sanity(data, tag):
assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'
def _check_sanity(data, tag: str):
assert isinstance(data, (torch.Tensor, dict)), \
f'{tag} must be torch.Tensor or dict'
def load_batch(self, data_iter, to_gpu=True):
"""Loads a batch from data iterator. It returns the data and labels which are