mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-23 12:36:03 +00:00
[feat] update optimizer bwd; ä¸
This commit is contained in:
@@ -49,11 +49,11 @@ class OptimizerWrapper:
|
||||
"""
|
||||
self.optim.zero_grad(*args, **kwargs)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
|
||||
"""
|
||||
Performs a backward pass on the loss.
|
||||
"""
|
||||
loss.backward(*args, **kwargs)
|
||||
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user