Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-30 17:22:21 +00:00)
[hotfix] make Gemini work for conv DNN (#1998)
commit a2d3266648
parent 155891113e
@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from ._utils import GeneralTensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+from ._utils import GeneralTensor, convert_to_colo_tensor
 
 
 def register_elementwise_op(op):
@@ -15,16 +17,21 @@ def register_elementwise_op(op):
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
-        output = op(input_tensor, *args, **kwargs)
-        if isinstance(input_tensor, ColoTensor):
-            if isinstance(output, str):
-                return output
-            if not isinstance(output, torch.Tensor):
-                raise NotImplementedError
-            return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
-                                                                    dist_attr=input_tensor.dist_spec))
+        if 'inplace' in kwargs:
+            # TODO(jiaruifang) inplace will cause bugs
+            input_tensor = input_tensor.clone()
+            return op(input_tensor, *args, **kwargs)
+        else:
+            output = op(input_tensor, *args, **kwargs)
+            # return output
+            if isinstance(input_tensor, ColoTensor):
+                if isinstance(output, str):
+                    return output
+                if not isinstance(output, torch.Tensor):
+                    raise NotImplementedError
+                return ColoTensor.from_torch_tensor(output,
+                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                        dist_attr=input_tensor.dist_spec))
 
 
 # Tensor op
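The substance of the hotfix is the new `inplace` branch: per the TODO, running an elementwise op in place (e.g. `F.relu(x, inplace=True)`) causes bugs under Gemini, so the wrapper now clones the input and lets the "inplace" op mutate the private copy instead. Below is a minimal sketch of that behavior with a hypothetical standalone helper, `elementwise_op_sketch`, standing in for the real `colo_op_impl` dispatch; it is an illustration of the cloning pattern, not the actual ColossalAI code path.

import torch
import torch.nn.functional as F


def elementwise_op_sketch(op, input_tensor, *args, **kwargs):
    # Mirrors the hotfix: if the caller requested an inplace op,
    # hand it a private clone so the original tensor is never mutated.
    if 'inplace' in kwargs:
        input_tensor = input_tensor.clone()
        return op(input_tensor, *args, **kwargs)
    return op(input_tensor, *args, **kwargs)


x = torch.randn(2, 3)
y = elementwise_op_sketch(F.relu, x, inplace=True)
print(torch.equal(y, F.relu(x)))  # True: x itself was left untouched

The clone costs one extra allocation, but it keeps the input tensor's storage read-only from the op's point of view. Conv nets commonly call ReLU with `inplace=True`, which is plausibly why the bug surfaced for conv DNNs under Gemini's managed memory, as the commit title suggests.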