[feat] add optim backward_b_by_grad

duanjunwen
2024-08-29 03:16:59 +00:00
parent b1419ef76a
commit 4c4b01b859
3 changed files with 178 additions and 6 deletions


@@ -58,6 +58,28 @@ class OptimizerWrapper:
    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        torch.autograd.backward(tensor, grad)

    def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
        """
        Performs the input-gradient backward pass (backward_b); only dx = w * dy is computed here.

        Args:
            tensor (Tensor): y or loss of the current chunk.
            grad_tensors (Tensor): dy of the current chunk.
            inputs (Tensor): x of the current chunk.
            retain_graph (bool): defaults to True; the graph is retained so that backward_w can run afterwards.
        """
        torch.autograd.backward(
            tensors=tensor,
            grad_tensors=grad_tensors,
            inputs=inputs,
            retain_graph=retain_graph,
        )

    def backward_w_by_grad(self):
        """
        Performs the weight-gradient backward pass (backward_w); only dw = x * dy is computed here.
        """
    def state_dict(self):
        """
        Returns the optimizer state.
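
For reference, a minimal standalone sketch (not part of this diff) of the mechanism backward_b_by_grad relies on: torch.autograd.backward with the inputs= argument accumulates gradients only into the given tensors, so the input gradient (dx) and the weight gradient (dw) can be computed in two separate passes over the same retained graph. The model and tensor names below are hypothetical.

import torch
import torch.nn as nn

# Hypothetical single-layer chunk: x -> linear -> y
model = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = model(x)
dy = torch.randn_like(y)  # upstream gradient (dy) for this chunk

# Phase 1 (backward_b): only dx = w * dy; retain the graph for the weight pass
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=x, retain_graph=True)
assert x.grad is not None and model.weight.grad is None

# Phase 2 (backward_w): only dw = x * dy, run on the retained graph
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(model.parameters()))
assert model.weight.grad is not None

Splitting the backward pass this way is what lets a pipeline schedule delay the weight-gradient computation (backward_w) independently of the input-gradient computation (backward_b).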