diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
index 4e0d46104..4f72ca4e0 100644
--- a/colossalai/auto_parallel/solver/conv_handler.py
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -4,6 +4,8 @@ import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
 
+__all__ = ['ConvHandler']
+
 
 class ConvHandler(OperatorHandler):
     """
diff --git a/colossalai/auto_parallel/solver/dot_handler.py b/colossalai/auto_parallel/solver/dot_handler.py
index 62d90c570..3ce2fedbd 100644
--- a/colossalai/auto_parallel/solver/dot_handler.py
+++ b/colossalai/auto_parallel/solver/dot_handler.py
@@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
 from .operator_handler import OperatorHandler
 from functools import reduce
 
+__all__ = ['DotHandler']
+
 
 class DotHandler(OperatorHandler):
     """
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 26dcfd892..63d17e6cb 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -9,6 +9,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 from .sharding_strategy import StrategiesVector
 
 
+__all__ = ['OperatorHandler']
+
 
 class OperatorHandler(ABC):
     '''
@@ -48,6 +50,9 @@ class OperatorHandler(ABC):
 
     @abstractmethod
     def register_strategy(self) -> StrategiesVector:
+        """
+        Register all possible sharding strategies for this operator into its StrategiesVector.
+        """
         pass
 
     def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec: