Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[pipeline] set optimizer to optional in execute_pipeline (#4630)
* set optimizer to optional in execute_pipeline
* arrange device and mixed precision in booster init
* fix execute_pipeline in booster.py
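After this change, an evaluation-only pass can run the pipeline schedule without constructing an optimizer first. A minimal caller-side sketch, assuming `plugin`, `model`, `criterion`, and `val_loader` were prepared by an earlier `HybridParallelPlugin` / `booster.boost()` setup (illustrative only, not code from this commit):

    # Evaluation pass: the optimizer argument is simply omitted and defaults to None.
    data_iter = iter(val_loader)
    result = plugin.execute_pipeline(data_iter,
                                     model,
                                     criterion,
                                     return_loss=True,
                                     return_outputs=True)
    # `result` is a dict; with return_loss/return_outputs set it carries the loss and outputs.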
@@ -443,15 +443,15 @@ class HybridParallelPlugin(PipelinePluginBase):
                          data_iter: Iterator,
                          model: HybridParallelModule,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
-                                          HybridParallelZeroOptimizer],
+                         optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+                                                    HybridParallelZeroOptimizer]] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
         # return loss or outputs if needed
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
         with ctx:
-            outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
+            outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
                                                            return_outputs)
         model.sync_shared_params()
         if isinstance(optimizer, HybridParallelZeroOptimizer):
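The diff boils down to the standard optional-argument pattern: `optimizer` becomes `Optional[...] = None` and moves after the required arguments in the `forward_backward_step` call, so downstream code has to tolerate `None`. A minimal, self-contained sketch of that pattern in plain PyTorch (illustrative only, not ColossalAI's implementation):

    from typing import Callable, Iterator, Optional

    import torch
    from torch import nn


    def run_step(model: nn.Module,
                 data_iter: Iterator,
                 criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 optimizer: Optional[torch.optim.Optimizer] = None) -> dict:
        # The forward pass always runs.
        inputs, targets = next(data_iter)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        if optimizer is not None:
            # Training path: backward and parameter update only when an optimizer is given.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return {'loss': loss.detach(), 'outputs': outputs.detach()}


    # Training call: optimizer supplied.
    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    train_iter = iter([(torch.randn(8, 4), torch.randn(8, 2))])
    run_step(model, train_iter, nn.MSELoss(), opt)

    # Evaluation call: optimizer omitted, forward-only.
    eval_iter = iter([(torch.randn(8, 4), torch.randn(8, 2))])
    run_step(model, eval_iter, nn.MSELoss())

Keeping the required arguments first and the now-optional `optimizer` after `criterion` is what lets callers such as `booster.execute_pipeline` drop it entirely.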