mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[autockpt] make it work. (#2257)
This commit is contained in:
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
|
||||
|
||||
|
||||
@operator_registry.register(BCAST_FUNC_OP)
|
||||
class BinaryElementwiseHandler(NodeHandler):
|
||||
class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
An BinaryBcastOpHandler is a node handler which deals with operations which have two
|
||||
operands and broadcasting occurs such as torch.add.
|
||||
|
@@ -3,7 +3,7 @@ from typing import Dict, List
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import ReshapeGenerator, StrategyGenerator
|
||||
|
||||
@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
|
||||
@operator_registry.register(torch.flatten)
|
||||
@operator_registry.register(torch.Tensor.unsqueeze)
|
||||
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
||||
class ReshapeHandler(NodeHandler):
|
||||
class ReshapeHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
||||
"""
|
||||
|
@@ -3,7 +3,7 @@ from typing import Dict, List
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import NodeHandler
|
||||
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
|
||||
|
||||
@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
|
||||
@operator_registry.register(torch.nn.modules.dropout.Dropout)
|
||||
@operator_registry.register(torch.Tensor.contiguous)
|
||||
@operator_registry.register(torch.nn.functional.dropout)
|
||||
class UnaryElementwiseHandler(NodeHandler):
|
||||
class UnaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user