Optimize pipeline schedule (#94)

* add pipeline shared module wrapper and update load batch

* added model parallel process group for amp and clip grad (#86)

* added model parallel process group for amp and clip grad

* update amp and clip with model parallel process group

* remove pipeline_prev/next group (#88)

* micro batch offload

* optimize pipeline gpu memory usage

* pipeline can receive tensor shape (#93)

* optimize pipeline gpu memory usage

* fix grad accumulation step counter

* rename classes and functions

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Commit 96780e6ee4 (parent e5b9f9a08d)
Author: ver217, committed via GitHub on 2021-12-30 15:56:46 +08:00
29 changed files with 423 additions and 290 deletions

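The Trainer diff below threads a new `return_output_label` flag from `fit`/`evaluate` down into `schedule.forward_backward_step`. When callers pass `False`, the schedule is told not to hand back per-step outputs and labels, which is where the pipeline GPU-memory saving in this PR comes from. A minimal, hypothetical call site (the trainer and dataloaders are assumed, not shown in this commit):

```python
# Hypothetical usage sketch; `trainer`, `train_loader` and `test_loader`
# are assumed to be constructed elsewhere -- none of this is in the diff.
trainer.fit(train_dataloader=train_loader,
            test_dataloader=test_loader,
            epochs=10,
            display_progress=True,
            return_output_label=False)  # skip returning outputs/labels to save GPU memory
```

Note that hooks still receive `output=(logits, label, loss)`, so with the flag off the logits/label slots are presumably empty, and metric hooks that need them should keep the default `True`.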

@@ -155,7 +155,8 @@ class Trainer:
     def _train_epoch(self,
                      train_dataloader: DataLoader,
                      epoch: int = None,
-                     display_progress: bool = False):
+                     display_progress: bool = False,
+                     return_output_label: bool = True):
         # set training state
         self._engine.train()
         data_iter = iter(train_dataloader)
@@ -175,7 +176,7 @@ class Trainer:
             # run 1 training step
             self.engine.zero_grad()
             logits, label, loss = self.schedule.forward_backward_step(
-                self.engine, data_iter, forward_only=False, return_loss=True)
+                self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label)
             self.engine.step()
             self._call_timer(action='stop', item='Train-step', keep_in_history=True)
             self._call_hooks('after_train_iter', output=(logits, label, loss))
@@ -197,7 +198,8 @@ class Trainer:
     def _eval(self,
               test_dataloader: DataLoader,
               epoch: int = None,
-              display_progress: bool = False):
+              display_progress: bool = False,
+              return_output_label: bool = True):
         # switch engine status
         self._engine.eval()

@@ -220,7 +222,7 @@ class Trainer:
                 self._call_hooks('before_test_iter')
                 self._call_timer(action='start', item='Test-step')
                 logits, label, loss = self.schedule.forward_backward_step(
-                    self.engine, data_iter, forward_only=True, return_loss=True)
+                    self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label)
                 self._call_timer(action='stop', item='Test-step', keep_in_history=True)
                 self._call_hooks('after_test_iter',
                                  output=(logits, label, loss))
@@ -246,6 +248,7 @@ class Trainer:
             test_interval: int = 1,
             hooks: List[BaseHook] = None,
             display_progress: bool = False,
+            return_output_label: bool = True,
             ):
         """Trains the model to fit training data.

@@ -256,6 +259,8 @@ class Trainer:
         :param test_interval: Interval of testing
         :param hooks_cfg: A list of hook configuration
         :param display_progress: If True, the training progress will be printed
+        :param return_output_label: If True, the output of model and the label will be returned
+        :type return_output_label: bool
         :type train_dataloader: DataLoader
         :type epochs: int
         :type max_steps: int
@@ -307,7 +312,8 @@ class Trainer:
             self._train_epoch(
                 train_dataloader=train_dataloader,
                 epoch=epoch,
-                display_progress=display_progress
+                display_progress=display_progress,
+                return_output_label=return_output_label
             )

             # start eval
@@ -315,6 +321,7 @@ class Trainer:
                 self._eval(test_dataloader=test_dataloader,
                            display_progress=display_progress,
                            epoch=epoch,
+                           return_output_label=return_output_label
                            )

             self._cur_epoch += 1
@@ -331,13 +338,16 @@ class Trainer:
     def evaluate(self,
                  test_dataloader: DataLoader,
                  hooks: List[BaseHook] = None,
-                 display_progress: bool = False):
+                 display_progress: bool = False,
+                 return_output_label: bool = True):
         """Evaluates the model with testing data.

         :param test_dataloader: DataLoader in testing
         :param display_progress: If True, the evaluation progress will be printed
+        :param return_output_label: If True, the output of model and the label will be returned
         :type test_dataloader: DataLoader
         :type display_progress: bool, optional
+        :type return_output_label: bool
         """
         # set display
         display_progress = self._should_display_progress(display_progress)
@@ -360,6 +370,7 @@ class Trainer:
         # eval
         self._eval(test_dataloader=test_dataloader,
                    display_progress=display_progress,
+                   return_output_label=return_output_label
                    )

     def predict(self, data: Union[Tensor, List[Tensor]]):
@@ -383,4 +394,4 @@ class Trainer:
         data_iter = iter(simple_dataloader)
         output, _, _ = self.schedule.forward_backward_step(
             self.engine, data_iter, forward_only=True, return_loss=False)
-        return output
+        return output
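The schedule side of the change does not appear in this excerpt. Purely as illustration, a `forward_backward_step` honoring the flag might follow the pattern below; everything except the flag itself (the engine's callable/criterion/backward interface, the micro-batch count) is an assumption, not the actual schedule implementation:

```python
import torch

def forward_backward_step(engine, data_iter, forward_only=False,
                          return_loss=True, return_output_label=True,
                          num_micro_batches=4):
    # Hypothetical sketch -- the real schedule code lives outside this diff.
    # Outputs and labels are accumulated only when requested, so a caller
    # passing return_output_label=False never holds them past one micro-batch.
    outputs, labels = [], []
    total_loss = torch.zeros(1)
    for _ in range(num_micro_batches):
        data, label = next(data_iter)
        output = engine(data)                       # forward (engine assumed callable)
        if return_loss:
            loss = engine.criterion(output, label)  # criterion attribute is assumed
            total_loss += loss.detach()
            if not forward_only:
                engine.backward(loss)               # backward helper is assumed
        if return_output_label:
            outputs.append(output.detach())
            labels.append(label)
        # With return_output_label=False, `output` and `label` go out of use
        # here, so their GPU memory can be reclaimed before the next micro-batch.
    if return_output_label:
        return torch.cat(outputs), torch.cat(labels), (total_loss if return_loss else None)
    return None, None, (total_loss if return_loss else None)
```

The return shape mirrors the `logits, label, loss` triple the Trainer unpacks above; the memory saving comes simply from never retaining per-micro-batch activations in the `outputs`/`labels` lists when the flag is off.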