[hotfix] make Gemini work for conv DNN (#1998)

Author: Jiarui Fang
Date: 2022-11-22 14:52:36 +08:00 (committed by GitHub)
Parent: 155891113e
Commit: a2d3266648


@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from ._utils import GeneralTensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+from ._utils import GeneralTensor, convert_to_colo_tensor
 
 
 def register_elementwise_op(op):
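
For context: register_elementwise_op wraps a plain torch elementwise op with the colo_op_impl decorator so that calling the op on a ColoTensor dispatches through __torch_function__. A minimal sketch of how the helper is presumably applied later in this same module (so it relies on the module-level imports shown above); the specific ops listed are illustrative assumptions, since the registration calls are not part of this diff:

    # Hypothetical registrations: each call installs a __torch_function__
    # handler so the wrapped op also accepts ColoTensor inputs.
    register_elementwise_op(F.relu)
    register_elementwise_op(F.gelu)
    register_elementwise_op(torch.clone)
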
@@ -15,8 +17,13 @@ def register_elementwise_op(op):
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
+        if 'inplace' in kwargs:
+            # TODO(jiaruifang) inplace will cause bugs
+            input_tensor = input_tensor.clone()
+            return op(input_tensor, *args, **kwargs)
+        else:
+            output = op(input_tensor, *args, **kwargs)
+            # return output
+            if isinstance(input_tensor, ColoTensor):
+                if isinstance(output, str):
+                    return output
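
Taken together: the 'inplace' branch clones the input before calling the op, and the else branch computes the op eagerly and, when the input is a ColoTensor, begins re-wrapping the result (the remainder of the function is cut off above). A rough end-to-end sketch of the dispatch this enables, assuming a typical single-process ColossalAI setup; the ProcessGroup, ColoTensorSpec, and from_torch_tensor calls are assumptions based on the imports in the first hunk, not code from this commit:

    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

    # Assumed setup: wrap a dense tensor as a ColoTensor on the default process group.
    pg = ProcessGroup()
    t = ColoTensor.from_torch_tensor(torch.randn(4, 4), spec=ColoTensorSpec(pg))

    out = F.relu(t)                  # routed to elementwise_op via __torch_function__
    out2 = F.relu(t, inplace=True)   # 'inplace' branch: the input is cloned first

The clone on the inplace path trades an extra copy for correctness, sidestepping the in-place bugs that the TODO comment references.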