From 763dc325f1f1ae364fbc69cd054d10f0dc8c01c8 Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Wed, 30 Mar 2022 09:35:46 +0800 Subject: [PATCH] [TP] Add gather_out arg to Linear (#541) --- .../nn/layer/colossalai_layer/linear.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/nn/layer/colossalai_layer/linear.py index 9bebd2b63..f98156500 100644 --- a/colossalai/nn/layer/colossalai_layer/linear.py +++ b/colossalai/nn/layer/colossalai_layer/linear.py @@ -1,4 +1,5 @@ import math +import inspect from typing import Callable from colossalai.utils import get_current_device @@ -78,15 +79,19 @@ class Linear(nn.Module): if self.layer.bias is not None: bias_initializer(self.layer.bias, fan_in=in_features) else: - self.layer = _parallel_linear[tensor_parallel]( - in_features, - out_features, - bias=bias, - dtype=dtype, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - **kwargs, - ) + linear_cls = _parallel_linear[tensor_parallel] + gather_output = kwargs.pop('gather_output', None) + if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs['gather_output'] = gather_output + self.layer = linear_cls( + in_features, + out_features, + bias=bias, + dtype=dtype, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + **kwargs, + ) @property def weight(self):