mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[pipelinable]use pipelinable context to initialize non-pipeline model (#816)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [pipeline]add module lazy init feature to support large model initization.
* [pipeline]add to_layer_list and partition method to support arbitrary non-pp model
* refactor the module structure
* polish
* [pipelinable]add unit test for pipelinable
* polish
* polish
* Fix CodeFactor issues.
This commit is contained in:
@@ -9,6 +9,28 @@ def _substitute_init_recursively(cls, func):
|
||||
func(subcls)
|
||||
|
||||
|
||||
def call_to_str(base, *args, **kwargs):
|
||||
"""Construct a string representation of a call.
|
||||
|
||||
Args:
|
||||
base (str): name of the call
|
||||
args (tuple, optional): args to ``base``
|
||||
kwargs (dict, optional): kwargs supplied to ``base``
|
||||
|
||||
Returns:
|
||||
str: A string representation of base(*args, **kwargs)
|
||||
"""
|
||||
name = f'{base}('
|
||||
if args:
|
||||
name += ', '.join(repr(arg) for arg in args)
|
||||
if kwargs:
|
||||
name += ', '
|
||||
if kwargs:
|
||||
name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
|
||||
name += ')'
|
||||
return name
|
||||
|
||||
|
||||
class InsertPostInitMethodToModuleSubClasses(object):
|
||||
|
||||
def __init__(self, default_dtype: Optional[torch.dtype] = None):
|
||||
@@ -28,7 +50,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
@functools.wraps(f)
|
||||
def wrapper(module: torch.nn.Module, *args, **kwargs):
|
||||
f(module, *args, **kwargs)
|
||||
self._post_init_method(module)
|
||||
self._post_init_method(module, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -71,7 +93,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
|
||||
return False
|
||||
|
||||
# To be implemented by inheriting classes
|
||||
def _post_init_method(self, module):
|
||||
def _post_init_method(self, module, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _pre_context_exec(self):
|
||||
|
Reference in New Issue
Block a user