diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py index ce7e8bef6..fe2eb0c99 100644 --- a/colossalai/nn/_ops/addmm.py +++ b/colossalai/nn/_ops/addmm.py @@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor, mat2: ColoTensor, beta: Number = 1, alpha: Number = 1, - *args) -> ColoTensor: + **kargs) -> ColoTensor: """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. This method computes a linear. """ @@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor, assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' ret_tensor = ColoTensor.from_torch_tensor( - tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha), + tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs), spec=ColoTensorSpec(mat2.get_process_group())) elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied if mat2.is_shard_1drow() and input_tensor.is_replicate():