[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -25,11 +25,12 @@ from .base import PipelineSchedule
 class OneForwardOneBackwardSchedule(PipelineSchedule):
-    def __init__(self,
-                 stage_manager: PipelineStageManager,
-                 num_microbatches: Optional[int] = None,
-                 microbatch_size: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        stage_manager: PipelineStageManager,
+        num_microbatches: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
+    ) -> None:
         """1F1B pipeline schedule.

         Args:
@@ -38,8 +39,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
         """
         super().__init__(stage_manager)
-        assert num_microbatches is not None or microbatch_size is not None, \
-            "Either num_microbatches or microbatch_size should be provided"
+        assert (
+            num_microbatches is not None or microbatch_size is not None
+        ), "Either num_microbatches or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
         self.microbatch_size = microbatch_size
@@ -62,12 +64,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
         if not self._use_microbatch_size:
-            assert self.batch_size % self.num_microbatches == 0, \
-                "Batch size should divided by the number of microbatches"
+            assert (
+                self.batch_size % self.num_microbatches == 0
+            ), "Batch size should divided by the number of microbatches"
             self.microbatch_size = self.batch_size // self.num_microbatches
         else:
-            assert self.batch_size % self.microbatch_size == 0, \
-                "Batch size should divided by the microbatch size"
+            assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
             self.num_microbatches = self.batch_size // self.microbatch_size

     def load_micro_batch(self) -> Any:
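The asserts reformatted in this hunk encode a simple invariant: exactly one of num_microbatches and microbatch_size needs to be supplied, and the other is derived from the batch size. Below is a minimal standalone sketch of that bookkeeping; the helper name resolve_microbatches is illustrative and is not part of the diffed file.

# Standalone sketch of the invariant enforced above; not part of the diffed file.
from typing import Optional, Tuple

def resolve_microbatches(
    batch_size: int,
    num_microbatches: Optional[int] = None,
    microbatch_size: Optional[int] = None,
) -> Tuple[int, int]:
    assert num_microbatches is not None or microbatch_size is not None, (
        "Either num_microbatches or microbatch_size should be provided"
    )
    if num_microbatches is not None:
        # num_microbatches takes precedence; microbatch_size is derived from it
        assert batch_size % num_microbatches == 0
        return num_microbatches, batch_size // num_microbatches
    assert batch_size % microbatch_size == 0
    return batch_size // microbatch_size, microbatch_size

print(resolve_microbatches(32, num_microbatches=4))   # (4, 8)
print(resolve_microbatches(32, microbatch_size=16))   # (2, 16)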
@@ -136,12 +138,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             self.comm.send_backward(input_object, prev_rank)

-    def forward_step(self,
-                     model: Module,
-                     input_obj: Optional[dict],
-                     criterion: Callable,
-                     accum_loss: Optional[torch.Tensor] = None,
-                     outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+    def forward_step(
+        self,
+        model: Module,
+        input_obj: Optional[dict],
+        criterion: Callable,
+        accum_loss: Optional[torch.Tensor] = None,
+        outputs: Optional[List[Any]] = None,
+    ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline

         Args:
@@ -159,7 +163,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
         output_obj = model_forward(model, micro_batch, input_obj)
         if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
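The division by self.num_microbatches above is what makes accum_loss come out as the mean loss over the whole batch rather than a sum of per-microbatch losses. A tiny numeric illustration (the loss values are made up):

import torch

# made-up per-microbatch mean losses for 4 microbatches
losses = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(1.0), torch.tensor(3.0)]
num_microbatches = len(losses)

accum_loss = torch.zeros(1)
for loss in losses:
    # same scaling as in the schedule: each contribution is loss / num_microbatches
    accum_loss.add_((loss / num_microbatches).detach())

print(accum_loss.item())  # 2.5 == (2 + 4 + 1 + 3) / 4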
@@ -169,8 +172,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         else:
             return output_obj

-    def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
-                      output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+    def backward_step(
+        self,
+        optimizer: OptimizerWrapper,
+        input_obj: Optional[dict],
+        output_obj: Union[dict, torch.Tensor],
+        output_obj_grad: Optional[dict],
+    ) -> Optional[dict]:
         """Backward one step of the pipeline

         Args:
@@ -208,13 +216,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj_grad[k] = v.grad
         return input_obj_grad

-    def forward_backward_step(self,
-                              model: Module,
-                              data_iter: Iterable,
-                              criterion: Callable[..., Any],
-                              optimizer: Optional[OptimizerWrapper] = None,
-                              return_loss: bool = False,
-                              return_outputs: bool = False) -> dict:
+    def forward_backward_step(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
@@ -273,7 +283,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
-            last_iteration = (i == (num_microbatches_remaining - 1))
+            last_iteration = i == (num_microbatches_remaining - 1)

             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
             if forward_only:
@@ -316,5 +326,5 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if outputs is not None:
             if isinstance(model, ModelWrapper):
                 model = model.unwrap()
-            outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
-        return {'loss': accum_loss, 'outputs': outputs}
+            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+        return {"loss": accum_loss, "outputs": outputs}