[pipeline] set optimizer to optional in execute_pipeline (#4630)

* set optimizer to optional in execute_pipeline

* arrange device and mixed precision in booster init

* fix execute_pipeline in booster.py
This commit is contained in:
Baizhou Zhang
2023-09-07 10:42:59 +08:00
committed by GitHub
parent c3d5fa3bac
commit 660eed9124
9 changed files with 30 additions and 27 deletions

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Callable, Iterator
from typing import Any, Callable, Iterator, Optional
import torch
@@ -15,7 +15,7 @@ class PipelinePluginBase(Plugin):
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: OptimizerWrapper,
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
pass