mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-30 17:22:21 +00:00
[hotfix] make Gemini work for conv DNN (#1998)
This commit is contained in:
parent 155891113e
commit a2d3266648
@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from ._utils import GeneralTensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+from ._utils import GeneralTensor, convert_to_colo_tensor
 
 
 def register_elementwise_op(op):
@@ -15,8 +17,13 @@ def register_elementwise_op(op):
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
-        output = op(input_tensor, *args, **kwargs)
-        if isinstance(input_tensor, ColoTensor):
-            if isinstance(output, str):
-                return output
+        if 'inplace' in kwargs:
+            # TODO(jiaruifang) inplace will cause bugs
+            input_tensor = input_tensor.clone()
+            return op(input_tensor, *args, **kwargs)
+        else:
+            output = op(input_tensor, *args, **kwargs)
+            # return output
+            if isinstance(input_tensor, ColoTensor):
+                if isinstance(output, str):
+                    return output
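The substance of the hotfix is the `inplace` branch: when a caller passes `inplace=True` to a registered elementwise op, as conv-style networks commonly do with `F.relu(x, inplace=True)`, the wrapper now clones the input and runs the op on the clone, so the original tensor, whose storage Gemini may be managing, is never mutated in place. Below is a minimal sketch of that clone-before-inplace pattern on plain `torch.Tensor`s; `demo_elementwise_op` is a hypothetical stand-in for the registered wrapper, not ColossalAI's actual API, and it omits the `ColoTensor` handling shown in the diff:

    import torch
    import torch.nn.functional as F

    def demo_elementwise_op(op, input_tensor, *args, **kwargs):
        # Hypothetical stand-in for the patched wrapper: if the caller asks
        # for an in-place op, run it on a clone so the caller's tensor
        # (whose storage a memory manager may own) is never mutated.
        if 'inplace' in kwargs:
            input_tensor = input_tensor.clone()
            return op(input_tensor, *args, **kwargs)
        else:
            return op(input_tensor, *args, **kwargs)

    x = torch.randn(4)
    x_before = x.clone()
    y = demo_elementwise_op(F.relu, x, inplace=True)
    assert torch.equal(x, x_before)    # the caller's tensor is untouched
    assert (y >= 0).all()              # the op still took effect

The trade-off is an extra copy on every in-place call, which the TODO in the diff acknowledges as a stopgap until in-place ops can be supported safely.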