[fix] fix use_fp8 flag

This commit is contained in:
duanjunwen 2024-11-01 07:05:24 +00:00
parent 5b5fbcff09
commit 0218e673db
2 changed files with 4 additions and 5 deletions

View File

@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
@ -793,7 +791,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):

View File

@ -2,7 +2,8 @@ from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm2 import *
# from .chatglm2 import *
from .command import *
from .deepseek import *
from .falcon import *