mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[pipeline] set optimizer to optional in execute_pipeline (#4630)
* set optimizer to optional in execute_pipeline * arrange device and mixed precision in booster init * fix execute_pipeline in booster.py
This commit is contained in:
@@ -237,18 +237,18 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
|
||||
def forward_backward_step(self,
|
||||
model_chunk: Module,
|
||||
optimizer: OptimizerWrapper,
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
return_loss: bool = False,
|
||||
return_outputs: bool = False) -> dict:
|
||||
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
|
||||
Args:
|
||||
model_chunk (List[Module]): Model Chunk to be trained.
|
||||
optimizer (OptimizerWrapper): Optimizer to be used.
|
||||
data_iter (Iterable): Data iterator.
|
||||
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
|
||||
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
|
||||
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
|
||||
|
||||
@@ -256,6 +256,8 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
dict: A dict with keys: 'loss' and 'outputs'.
|
||||
"""
|
||||
forward_only = not torch.is_grad_enabled()
|
||||
if optimizer is None:
|
||||
assert forward_only, "Optimizer should be passed when doing backward."
|
||||
|
||||
self.load_batch(data_iter)
|
||||
num_model_chunks = len(model_chunk)
|
||||
|
Reference in New Issue
Block a user