mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
[fix] fix use_fp8 flag
This commit is contained in:
parent
5b5fbcff09
commit
0218e673db
@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(
|
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
|
||||||
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
|
|
||||||
):
|
|
||||||
ctx.save_for_backward(input_, weight, bias)
|
ctx.save_for_backward(input_, weight, bias)
|
||||||
ctx.use_bias = bias is not None
|
ctx.use_bias = bias is not None
|
||||||
ctx.process_group = process_group
|
ctx.process_group = process_group
|
||||||
@ -793,7 +791,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||||||
if ctx.async_grad_reduce_scatter:
|
if ctx.async_grad_reduce_scatter:
|
||||||
handle.wait()
|
handle.wait()
|
||||||
|
|
||||||
return output, grad_weight, grad_bias, None, None, None, None, None, None
|
return output, grad_weight, grad_bias, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||||
|
@ -2,7 +2,8 @@ from .albert import *
|
|||||||
from .bert import *
|
from .bert import *
|
||||||
from .blip2 import *
|
from .blip2 import *
|
||||||
from .bloom import *
|
from .bloom import *
|
||||||
from .chatglm2 import *
|
|
||||||
|
# from .chatglm2 import *
|
||||||
from .command import *
|
from .command import *
|
||||||
from .deepseek import *
|
from .deepseek import *
|
||||||
from .falcon import *
|
from .falcon import *
|
||||||
|
Loading…
Reference in New Issue
Block a user