[doc] add booster docstring and fix autodoc (#3789)

* [doc] add docstr for booster methods

* [doc] fix autodoc
Hongxin Liu
2023-05-22 10:56:47 +08:00
committed by GitHub
parent 3c07a2846e
commit 72688adb2f
5 changed files with 16 additions and 70 deletions


@@ -130,6 +130,12 @@ class Booster:
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
"""Backward pass.
Args:
loss (torch.Tensor): The loss to be backpropagated.
optimizer (Optimizer): The optimizer that performs the backward pass on the loss.
"""
# TODO: implement this method with plugin
optimizer.backward(loss)
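For context, a minimal usage sketch of how this backward method is typically called after boost; the model, criterion, dataloader, and booster names below are illustrative placeholders, not part of this diff:

# Hypothetical training-step sketch (names are placeholders):
# call booster.backward(loss, optimizer) instead of loss.backward()
# so the plugin/optimizer can intercept the backward pass.
for inputs, targets in dataloader:
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    booster.backward(loss, optimizer)  # delegates to optimizer.backward(loss)
    optimizer.step()
    optimizer.zero_grad()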
@@ -146,6 +152,14 @@ class Booster:
pass
def no_sync(self, model: nn.Module) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
Args:
model (nn.Module): The model whose gradient synchronization is to be disabled.
Returns:
contextmanager: A context manager that disables gradient synchronization.
"""
assert self.plugin is not None, 'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
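As a hedged illustration of the documented no_sync context, a gradient-accumulation sketch; the accum_steps variable and the assumption that the active plugin supports no_sync are not taken from this commit:

# Hypothetical gradient-accumulation sketch (assumes the active plugin
# supports no_sync; accum_steps and the data names are placeholders):
for step, (inputs, targets) in enumerate(dataloader):
    loss = criterion(model(inputs), targets)
    if (step + 1) % accum_steps != 0:
        # skip gradient all-reduce on intermediate micro-batches
        with booster.no_sync(model):
            booster.backward(loss, optimizer)
    else:
        # synchronize gradients on the final micro-batch, then update
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()

Skipping synchronization on intermediate micro-batches avoids redundant all-reduce traffic; gradients still accumulate locally and are reduced once on the final micro-batch.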