mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -37,19 +37,22 @@ class NonPipelineSchedule(BaseSchedule):
|
||||
|
||||
if data_process_func:
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(sig.parameters) == 1, \
|
||||
'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \
|
||||
'which is a tuple of tensors for the current batch, ' \
|
||||
'i.e. data_process_func(dataloader_output).'
|
||||
assert len(sig.parameters) == 1, (
|
||||
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
|
||||
"which is a tuple of tensors for the current batch, "
|
||||
"i.e. data_process_func(dataloader_output)."
|
||||
)
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
||||
def forward_backward_step(self,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True):
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
||||
@@ -64,8 +67,9 @@ class NonPipelineSchedule(BaseSchedule):
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
assert forward_only or return_loss, \
|
||||
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
batch_data = self.load_batch(data_iter)
|
||||
if self.data_process_func:
|
||||
data, label = self.data_process_func(batch_data)
|
||||
|
Reference in New Issue
Block a user