mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -16,11 +16,11 @@ from .base import PipelineSchedule
|
||||
|
||||
|
||||
class InterleavedSchedule(PipelineSchedule):
|
||||
|
||||
def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
|
||||
self.num_model_chunks = num_model_chunks
|
||||
assert num_microbatches % self.num_model_chunks == 0, \
|
||||
"Number of microbatches should be an integer multiple of number of model chunks"
|
||||
assert (
|
||||
num_microbatches % self.num_model_chunks == 0
|
||||
), "Number of microbatches should be an integer multiple of number of model chunks"
|
||||
super().__init__(stage_manager)
|
||||
self.comm = PipelineP2PCommunication(stage_manager)
|
||||
self.num_microbatches = num_microbatches
|
||||
@@ -42,8 +42,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
|
||||
def load_micro_batch(self, model_chunk_id: int) -> Any:
|
||||
@@ -72,7 +71,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
|
||||
if not forward:
|
||||
model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
|
||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def is_first_stage(self, model_chunk_id: int) -> bool:
|
||||
@@ -161,13 +160,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
if not self.is_first_stage(model_chunk_id):
|
||||
self.comm.send_backward(input_object, prev_rank)
|
||||
|
||||
def forward_step(self,
|
||||
model_chunk: Module,
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
|
||||
def forward_step(
|
||||
self,
|
||||
model_chunk: Module,
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
) -> Union[torch.Tensor, dict]:
|
||||
"""Forward one step of the pipeline
|
||||
Args:
|
||||
model (Module): Model Chunk to be run
|
||||
@@ -195,8 +196,13 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
else:
|
||||
return output_obj
|
||||
|
||||
def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
|
||||
def backward_step(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
) -> Optional[dict]:
|
||||
"""Backward one step of the pipeline
|
||||
|
||||
Args:
|
||||
@@ -235,13 +241,15 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
input_obj_grad[k] = v.grad
|
||||
return input_obj_grad
|
||||
|
||||
def forward_backward_step(self,
|
||||
model_chunk: Module,
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False) -> dict:
|
||||
def forward_backward_step(
|
||||
self,
|
||||
model_chunk: Module,
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
|
||||
Args:
|
||||
@@ -321,7 +329,7 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
|
||||
last_iteration = (i == (num_microbatches_remaining - 1))
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
if forward_only:
|
||||
@@ -369,4 +377,4 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
return {'loss': accum_loss, 'outputs': outputs}
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
Reference in New Issue
Block a user