Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[doc] add booster docstring and fix autodoc (#3789)
* [doc] add docstr for booster methods
* [doc] fix autodoc
@@ -130,6 +130,12 @@ class Booster:
         return model, optimizer, criterion, dataloader, lr_scheduler

     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
+        """Backward pass.
+
+        Args:
+            loss (torch.Tensor): The loss to be backpropagated.
+            optimizer (Optimizer): The optimizer to be updated.
+        """
         # TODO: implement this method with plugin
         optimizer.backward(loss)
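For context, Booster.backward is the plugin-aware replacement for calling loss.backward() directly. Below is a minimal usage sketch, not part of this commit: the Booster() construction, the boost() call, and the plain-PyTorch components are assumptions drawn from the surrounding API rather than from this hunk, and depending on the installed version a CUDA device or a launched distributed environment may be required.

import torch
import torch.nn as nn
from torch.optim import SGD
from colossalai.booster import Booster  # assumed import path, not shown in this diff

# Hypothetical setup: a tiny model, optimizer, and loss, wrapped by a default Booster.
model = nn.Linear(16, 2)
optimizer = SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

booster = Booster()  # assumed: plugin-less booster; boost() returns the five objects shown above
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# One training step: booster.backward delegates to optimizer.backward(loss),
# as shown in the hunk above, instead of calling loss.backward() directly.
x = torch.randn(8, 16)
y = torch.randint(0, 2, (8,))
loss = criterion(model(x), y)
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()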
@@ -146,6 +152,14 @@ class Booster:
         pass

     def no_sync(self, model: nn.Module) -> contextmanager:
+        """Context manager to disable gradient synchronization across DP process groups.
+
+        Args:
+            model (nn.Module): The model to be disabled gradient synchronization.
+
+        Returns:
+            contextmanager: Context to disable gradient synchronization.
+        """
         assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
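Per the asserts above, no_sync is only usable when the booster was built with a plugin whose support_no_sync flag is true. A hypothetical gradient-accumulation sketch follows, reusing the raw model, optimizer, and criterion definitions from the previous sketch; the TorchDDPPlugin import and the requirement of an already-initialized distributed environment (e.g. via colossalai.launch) are assumptions, not part of this diff.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin  # assumed plugin that supports no_sync

# Assumed: torch.distributed has already been initialized (e.g. via colossalai.launch),
# and model/optimizer/criterion are the raw (un-boosted) objects from the previous sketch.
booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

micro_batches = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(4)]

# Accumulate gradients locally for all but the last micro-batch; inside
# no_sync the DP all-reduce is skipped, as the docstring above describes.
with booster.no_sync(model):
    for x, y in micro_batches[:-1]:
        loss = criterion(model(x), y)
        booster.backward(loss, optimizer)

# The last micro-batch runs outside no_sync, so its backward pass
# triggers gradient synchronization across the DP group before stepping.
x, y = micro_batches[-1]
loss = criterion(model(x), y)
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()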