[feat] update optimizer bwd

This commit is contained in:
duanjunwen
2024-09-29 09:59:41 +00:00
parent d63479553c
commit 5c8bbf63a8
5 changed files with 36 additions and 16 deletions

View File

@@ -49,11 +49,11 @@ class OptimizerWrapper:
"""
self.optim.zero_grad(*args, **kwargs)
def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
"""
Performs a backward pass on the loss.
"""
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""