[pipeline] set optimizer to optional in execute_pipeline (#4630)

* set optimizer to optional in execute_pipeline

* arrange device and mixed precision in booster init

* fix execute_pipeline in booster.py
This commit is contained in:
Baizhou Zhang
2023-09-07 10:42:59 +08:00
committed by GitHub
parent c3d5fa3bac
commit 660eed9124
9 changed files with 30 additions and 27 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, Optional
from torch import Tensor
from torch.nn import Module
@@ -14,18 +14,18 @@ class PipelineSchedule:
def forward_backward_step(self,
model: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[[Any, Any], Tensor],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Forward and backward step for pipeline training.
Args:
model (Module): Model to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

View File

@@ -237,18 +237,18 @@ class InterleavedSchedule(PipelineSchedule):
def forward_backward_step(self,
model_chunk: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model_chunk (List[Module]): Model Chunk to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
@@ -256,6 +256,8 @@ class InterleavedSchedule(PipelineSchedule):
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)
num_model_chunks = len(model_chunk)

View File

@@ -210,18 +210,18 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
def forward_backward_step(self,
model: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model (Module): Model to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
@@ -229,6 +229,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert forward_only, "Optimizer should be passed when doing backward."
self.load_batch(data_iter)