[autockpt] make it work. (#2257)

This commit is contained in:
Super Daniel
2023-01-02 23:37:45 +08:00
committed by GitHub
parent ac3739930d
commit 3ccf58aa76
4 changed files with 12 additions and 12 deletions

View File

@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(NodeHandler):
class BinaryElementwiseHandler(MetaInfoNodeHandler):
"""
An BinaryBcastOpHandler is a node handler which deals with operations which have two
operands and broadcasting occurs such as torch.add.

View File

@@ -3,7 +3,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator
@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""

View File

@@ -3,7 +3,7 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(NodeHandler):
class UnaryElementwiseHandler(MetaInfoNodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""