diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index f9b1d2815..32b9b1b74 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -3,6 +3,8 @@ from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.colo_tensor import ColoTensor from colossalai.context import ParallelMode from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input +from colossalai.nn.layer.utils import divide +from colossalai.core import global_context as gpc from packaging import version @colo_op_impl(torch.nn.functional.linear) @@ -29,10 +31,11 @@ def colo_linear(types, args, kwargs, pg): if weight.shard_spec == None: return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias) elif weight.shard_spec == '1Drow': - """ - Input:S[1] x Weight:S[0] = Output:P - All-Reduce(Output) + bias = res - """ + # Input:S[1] x Weight:S[0] = Output:P + # All-Reduce(Output) + bias = res + assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size[-1], \ + 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size) # Input:S[1] input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1) # Output:P diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index f72cd02be..3a567f223 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,7 +1,6 @@ from numpy import product import torch -from typing import Tuple -import numpy +from typing import Tuple, Optional from .op_wrapper import _COLOSSAL_OPS class ColoTensor(object): diff --git a/tests/test_tensor/test_tensor_utils/__init__.py b/tests/test_tensor/_utils/__init__.py similarity index 100% rename from tests/test_tensor/test_tensor_utils/__init__.py rename to tests/test_tensor/_utils/__init__.py diff --git a/tests/test_tensor/test_tensor_utils/_util.py b/tests/test_tensor/_utils/_util.py similarity index 100% rename from tests/test_tensor/test_tensor_utils/_util.py rename to tests/test_tensor/_utils/_util.py diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index a6147463a..bd3adcf8f 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -14,7 +14,7 @@ from colossalai.utils import free_port from colossalai.core import global_context as gpc import torch.distributed as dist -from test_tensor_utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk +from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk def run_linear_tp1d_row_test(): device = get_current_device()