mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -47,7 +47,8 @@ class BaseSchedule(ABC):
             data = {k: self._move_tensor(v) for k, v in data.items()}
         else:
             raise TypeError(
-                f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
+                f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}"
+            )
         return data

     def _get_batch_size(self, data):
@@ -72,7 +73,7 @@ class BaseSchedule(ABC):
             Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label).
         """
         if data_iter is None:
-            raise RuntimeError('Dataloader is not defined.')
+            raise RuntimeError("Dataloader is not defined.")
         batch_data = next(data_iter)

         if to_gpu:
@@ -81,17 +82,17 @@ class BaseSchedule(ABC):
         return batch_data

     def pre_processing(self, engine):
-        """To perform actions before running the schedule.
-        """
-        pass
+        """To perform actions before running the schedule."""

     @abstractmethod
-    def forward_backward_step(self,
-                              engine,
-                              data_iter: Iterable,
-                              forward_only: bool,
-                              return_loss: bool = True,
-                              return_output_label: bool = True):
+    def forward_backward_step(
+        self,
+        engine,
+        data_iter: Iterable,
+        forward_only: bool,
+        return_loss: bool = True,
+        return_output_label: bool = True,
+    ):
         """The process function over a batch of dataset for training or evaluation.

         Args:
@@ -101,7 +102,6 @@ class BaseSchedule(ABC):
             return_loss (bool, optional): If False, the loss won't be returned.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         """
-        pass

     @staticmethod
     def _call_engine(engine, inputs):
@@ -113,13 +113,14 @@ class BaseSchedule(ABC):
             return engine(**inputs)
         else:
             TypeError(
-                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}")
+                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
+            )

     @staticmethod
     def _call_engine_criterion(engine, outputs, labels):
-        assert isinstance(outputs,
-                          (torch.Tensor, list, tuple,
-                           dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
+        assert isinstance(
+            outputs, (torch.Tensor, list, tuple, dict)
+        ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         if isinstance(labels, torch.Tensor):
@@ -134,6 +135,8 @@ class BaseSchedule(ABC):
         elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
             raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
         else:
-            raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \
+            raise TypeError(
+                f"Expected model outputs and labels to be of type torch.Tensor ' \
                 '(which is auto-converted to tuple), list, tuple, or dict, ' \
-                'but got {type(outputs)} (model outputs) and {type(labels)} (labels)")
+                'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
+            )
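
The reformatted forward_backward_step signature above is the interface every concrete schedule has to implement. As a reading aid, here is a minimal sketch of such a subclass; it is not part of this commit. The class name, the load_batch call, and engine.backward are assumptions about the surrounding Engine/Schedule API, while _call_engine and _call_engine_criterion are the static helpers visible in the hunks above.

from typing import Iterable

# Assumption: `BaseSchedule` (the class in this diff) is importable; the exact
# module path depends on the ColossalAI version in use, e.g.
# from colossalai.legacy.engine.schedule import BaseSchedule


class SimpleScheduleSketch(BaseSchedule):
    """Hypothetical non-pipeline schedule illustrating the abstract interface."""

    def forward_backward_step(
        self,
        engine,
        data_iter: Iterable,
        forward_only: bool,
        return_loss: bool = True,
        return_output_label: bool = True,
    ):
        # Assumption: the schedule's batch loader yields a (data, label) pair,
        # matching the docstring of the loader method shown in the diff.
        data, label = self.load_batch(data_iter)

        # Forward pass; `_call_engine` dispatches on tensor/list/tuple/dict inputs.
        output = self._call_engine(engine, data)

        loss = None
        if return_loss:
            # `_call_engine_criterion` type-checks outputs and labels before
            # invoking the engine's criterion.
            loss = self._call_engine_criterion(engine, output, label)

        if not forward_only:
            # Assumption: the engine exposes a `backward` method for the loss.
            engine.backward(loss)

        if return_output_label:
            return output, label, loss
        return None, None, loss

Note that the double quotes, the one-parameter-per-line signature, and the trailing comma after the last parameter are typical of black-style formatting; the trailing comma is what keeps the formatter from collapsing the signature back onto a single line.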