mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 23:11:55 +00:00
[fix] fix use_fp8 flag
This commit is contained in:
parent
5b5fbcff09
commit
0218e673db
@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
|
||||
):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
@ -793,7 +791,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
handle.wait()
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None, None
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
|
@ -2,7 +2,8 @@ from .albert import *
|
||||
from .bert import *
|
||||
from .blip2 import *
|
||||
from .bloom import *
|
||||
from .chatglm2 import *
|
||||
|
||||
# from .chatglm2 import *
|
||||
from .command import *
|
||||
from .deepseek import *
|
||||
from .falcon import *
|
||||
|
Loading…
Reference in New Issue
Block a user