diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index d91807607..8c2e6e7c5 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -723,9 +723,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(
-        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
-    ):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -793,7 +791,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         if ctx.async_grad_reduce_scatter:
             handle.wait()
 
-        return output, grad_weight, grad_bias, None, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None
 
 
 class _SplitForwardGatherBackward(torch.autograd.Function):
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 4adc38619..029968231 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -2,7 +2,8 @@ from .albert import *
 from .bert import *
 from .blip2 import *
 from .bloom import *
-from .chatglm2 import *
+
+# from .chatglm2 import *
 from .command import *
 from .deepseek import *
 from .falcon import *
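
Note on the first file's change: the two hunks are coupled by the torch.autograd.Function contract, under which backward() must return exactly one value per forward() argument (after ctx), with None for non-differentiable arguments such as flags and process groups. Removing the `overlap` parameter from forward() therefore requires removing one `None` from backward()'s return tuple, which is what the second hunk does. The sketch below is a minimal, hypothetical illustration of that contract (the class `_ScaleByFlag` is not part of this PR or of ColossalAI):

import torch


class _ScaleByFlag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, double_it):  # two arguments after ctx ...
        ctx.double_it = double_it
        return input_ * 2 if double_it else input_

    @staticmethod
    def backward(ctx, grad_output):  # ... so backward returns exactly two values
        grad_input = grad_output * 2 if ctx.double_it else grad_output
        return grad_input, None  # None for the non-tensor flag


x = torch.ones(3, requires_grad=True)
_ScaleByFlag.apply(x, True).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))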