mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[autoparallel] handled illegal sharding strategy (#1728)
* [autoparallel] handled illegal sharding strategy * polish code
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
__all__ = ['exception_handler']
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingSpecException
|
||||
|
||||
__all__ = ['ignore_sharding_exception']
|
||||
|
||||
|
||||
def exception_handler(func):
|
||||
def ignore_sharding_exception(func):
|
||||
"""
|
||||
A function wrapper to handle the AssertionError in the function.
|
||||
A function wrapper to handle the ShardingSpecException in the function.
|
||||
If ShardingSpecException occurs, this function will return None.
|
||||
|
||||
Usage:
|
||||
# mute the assertion error in the function
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def do_something():
|
||||
...
|
||||
"""
|
||||
@@ -18,9 +21,11 @@ def exception_handler(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
logger = get_dist_logger()
|
||||
rst = func(*args, **kwargs)
|
||||
return rst
|
||||
except AssertionError as e:
|
||||
warnings.warn(f'{e}')
|
||||
except ShardingSpecException as e:
|
||||
logger.debug(e)
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
Reference in New Issue
Block a user