[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
This commit is contained in:
Frank Lee
2022-10-19 12:53:06 +08:00
committed by GitHub
parent cbe9a4cb45
commit eee84908d4
36 changed files with 459 additions and 303 deletions

View File

@@ -1,16 +1,19 @@
import functools
import warnings
__all__ = ['exception_handler']
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpecException
__all__ = ['ignore_sharding_exception']
def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper to handle the AssertionError in the function.
A function wrapper to handle the ShardingSpecException in the function.
If ShardingSpecException occurs, this function will return None.
Usage:
# mute the assertion error in the function
@exception_handler
@ignore_sharding_exception
def do_something():
...
"""
@@ -18,9 +21,11 @@ def exception_handler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
logger = get_dist_logger()
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
except ShardingSpecException as e:
logger.debug(e)
return None
return wrapper