From d39e11dffb486f42720c229419269df7a86058ca Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 24 Aug 2022 15:44:07 +0800 Subject: [PATCH] [autoparallel] added namespace constraints (#1490) --- colossalai/auto_parallel/solver/conv_handler.py | 2 ++ colossalai/auto_parallel/solver/dot_handler.py | 2 ++ colossalai/auto_parallel/solver/operator_handler.py | 6 ++++++ 3 files changed, 10 insertions(+) diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py index 4e0d46104..4f72ca4e0 100644 --- a/colossalai/auto_parallel/solver/conv_handler.py +++ b/colossalai/auto_parallel/solver/conv_handler.py @@ -4,6 +4,8 @@ import torch from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler +__all__ = ['ConvHandler'] + class ConvHandler(OperatorHandler): """ diff --git a/colossalai/auto_parallel/solver/dot_handler.py b/colossalai/auto_parallel/solver/dot_handler.py index 62d90c570..3ce2fedbd 100644 --- a/colossalai/auto_parallel/solver/dot_handler.py +++ b/colossalai/auto_parallel/solver/dot_handler.py @@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, from .operator_handler import OperatorHandler from functools import reduce +__all__ = ['DotHandler'] + class DotHandler(OperatorHandler): """ diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py index 26dcfd892..63d17e6cb 100644 --- a/colossalai/auto_parallel/solver/operator_handler.py +++ b/colossalai/auto_parallel/solver/operator_handler.py @@ -1,3 +1,4 @@ +from webbrowser import Opera import torch import torch.nn as nn from abc import ABC, abstractmethod @@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec from .sharding_strategy import StrategiesVector +__all__ = ['OperatorHandler'] + class OperatorHandler(ABC): ''' @@ -48,6 +51,9 @@ class OperatorHandler(ABC): @abstractmethod def register_strategy(self) -> StrategiesVector: + """ + Register + """ pass def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec: