diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 47e693720..6f82f2c07 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -24,6 +24,13 @@ class ColoTensor(object): for kwarg in kwargs.values(): if isinstance(kwarg, ColoTensor): return _COLOSSAL_OPS[func](types, args, kwargs, None) + else: + # If we have not hijact the function, convert the ColoTensors in args and kwargs to torch tensors. + args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args] + if kwargs is None: + kwargs = {} - raise RuntimeError(f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for ColoTensor!") + kwargs = { + kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs + } + return func(*args, **kwargs)