[fix] fix use_fp8 flag

Author: duanjunwen
Date:   2024-11-01 07:05:24 +00:00
Parent: 5b5fbcff09
Commit: 0218e673db
2 changed files with 4 additions and 5 deletions


@@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """ """

     @staticmethod
-    def forward(
-        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
-    ):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -793,7 +791,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         if ctx.async_grad_reduce_scatter:
             handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
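
The change to the return statement follows from the signature change above: for a torch.autograd.Function, backward must return exactly one value per forward argument (excluding ctx), with None for arguments that carry no gradient. Removing overlap from forward therefore removes one None from the tuple, going from nine values to eight. Below is a minimal self-contained sketch of that contract; the toy class _ToyMatmul, its shapes, and the call site are illustrative, not the repository's implementation.

import torch

# Minimal sketch, not the repository's class: backward() must return one value
# per forward() argument after ctx, using None for arguments without gradients.
class _ToyMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        out = input_ @ weight
        return (out + bias) if ctx.use_bias else out

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = grad_output @ weight.t()
        grad_weight = input_.t() @ grad_output
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        # 8 forward arguments (after ctx) -> 8 return values: 3 gradients + 5 Nones,
        # matching the new signature; the old 9-argument signature needed 6 Nones.
        return grad_input, grad_weight, grad_bias, None, None, None, None, None

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
b = torch.randn(2, requires_grad=True)
_ToyMatmul.apply(x, w, b, None, True, 0, False, False).sum().backward()

If the tuple length and the forward argument count drift apart, autograd raises a RuntimeError about an incorrect number of gradients when backward runs, which is what keeping the two hunks in sync avoids.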


@@ -2,7 +2,8 @@ from .albert import *
 from .bert import *
 from .blip2 import *
 from .bloom import *
-from .chatglm2 import *
+# from .chatglm2 import *
 from .command import *
 from .deepseek import *
 from .falcon import *
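
Commenting out the wildcard re-export removes the chatglm2 names from this package's namespace, so any consumer that still needs them has to import the submodule directly. When the goal is only to tolerate a submodule that fails to import, a common alternative idiom is to guard the re-export instead of commenting it out. The sketch below shows that pattern for illustration; it is not what this commit does, and it only helps when the failure is an ImportError.

# Alternative to commenting out the line: keep the re-export but let a missing
# or broken optional submodule degrade gracefully instead of breaking the whole
# package import. (Illustrative only; not part of this commit.)
try:
    from .chatglm2 import *  # noqa: F401,F403
except ImportError:
    pass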